diff --git a/app/PersonEntity b/app/PersonEntity deleted file mode 100644 index e69de29..0000000 diff --git a/app/build.gradle.kts b/app/build.gradle.kts index e2d8e18..a013653 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -95,6 +95,5 @@ dependencies { // Workers implementation(libs.androidx.work.runtime.ktx) implementation(libs.androidx.hilt.work) - - + ksp(libs.androidx.hilt.compiler) } \ No newline at end of file diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index ce43118..428c235 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -3,27 +3,33 @@ xmlns:tools="http://schemas.android.com/tools"> + android:theme="@style/Theme.SherpAI2"> + + + + + + android:exported="true"> - + \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt index 5773ab3..a128c5c 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt @@ -2,32 +2,32 @@ package com.placeholder.sherpai2.data.local import androidx.room.Database import androidx.room.RoomDatabase +import androidx.sqlite.db.SupportSQLiteDatabase +import androidx.room.migration.Migration import com.placeholder.sherpai2.data.local.dao.* import com.placeholder.sherpai2.data.local.entity.* /** * AppDatabase - Complete database for SherpAI2 * - * VERSION 7 - Added face detection cache to ImageEntity: - * - hasFaces: Boolean? - * - faceCount: Int? - * - facesLastDetected: Long? - * - faceDetectionVersion: Int? + * VERSION 10 - User Feedback Loop + * - Added UserFeedbackEntity for storing user corrections + * - Enables cluster refinement before training + * - Ground truth data for improving clustering * - * ENTITIES: - * - YOUR EXISTING: Image, Tag, Event, junction tables - * - NEW: PersonEntity (people in your app) - * - NEW: FaceModelEntity (face embeddings, links to PersonEntity) - * - NEW: PhotoFaceTagEntity (face detections, links to ImageEntity + FaceModelEntity) + * VERSION 9 - Enhanced Face Cache + * - Added FaceCacheEntity for per-face metadata + * - Stores quality scores, embeddings, bounding boxes + * - Enables intelligent face filtering for clustering * - * DEV MODE: Using destructive migration (fallbackToDestructiveMigration) - * - Fresh install on every schema change - * - No manual migrations needed during development + * VERSION 8 - PHASE 2: Multi-centroid face models + age tagging + * - Added PersonEntity.isChild, siblingIds, familyGroupId + * - Changed FaceModelEntity.embedding → centroidsJson (multi-centroid) + * - Added PersonAgeTagEntity table for searchable age tags * - * PRODUCTION MODE: Add proper migrations before release - * - See DatabaseMigration.kt for migration code - * - Remove fallbackToDestructiveMigration() - * - Add .addMigrations(MIGRATION_6_7) + * MIGRATION STRATEGY: + * - Development: fallbackToDestructiveMigration (fresh install) + * - Production: Add migrations before release */ @Database( entities = [ @@ -42,16 +42,18 @@ import com.placeholder.sherpai2.data.local.entity.* PersonEntity::class, FaceModelEntity::class, PhotoFaceTagEntity::class, + PersonAgeTagEntity::class, + FaceCacheEntity::class, + UserFeedbackEntity::class, // NEW: User corrections // ===== COLLECTIONS ===== CollectionEntity::class, CollectionImageEntity::class, CollectionFilterEntity::class ], - version = 7, // INCREMENTED for face detection cache + version = 10, // INCREMENTED for user feedback exportSchema = false ) -// No TypeConverters needed - embeddings stored as strings abstract class AppDatabase : RoomDatabase() { // ===== CORE DAOs ===== @@ -66,33 +68,187 @@ abstract class AppDatabase : RoomDatabase() { abstract fun personDao(): PersonDao abstract fun faceModelDao(): FaceModelDao abstract fun photoFaceTagDao(): PhotoFaceTagDao + abstract fun personAgeTagDao(): PersonAgeTagDao + abstract fun faceCacheDao(): FaceCacheDao + abstract fun userFeedbackDao(): UserFeedbackDao // NEW // ===== COLLECTIONS DAO ===== abstract fun collectionDao(): CollectionDao } /** - * MIGRATION NOTES FOR PRODUCTION: + * MIGRATION 7 → 8 (Phase 2) * - * When ready to ship to users, replace destructive migration with proper migration: + * Changes: + * 1. Add isChild, siblingIds, familyGroupId to persons table + * 2. Rename embedding → centroidsJson in face_models table + * 3. Create person_age_tags table + */ +val MIGRATION_7_8 = object : Migration(7, 8) { + override fun migrate(database: SupportSQLiteDatabase) { + + // ===== STEP 1: Update persons table ===== + database.execSQL("ALTER TABLE persons ADD COLUMN isChild INTEGER NOT NULL DEFAULT 0") + database.execSQL("ALTER TABLE persons ADD COLUMN siblingIds TEXT DEFAULT NULL") + database.execSQL("ALTER TABLE persons ADD COLUMN familyGroupId TEXT DEFAULT NULL") + + // Create index on familyGroupId for sibling queries + database.execSQL("CREATE INDEX IF NOT EXISTS index_persons_familyGroupId ON persons(familyGroupId)") + + // ===== STEP 2: Update face_models table ===== + // Rename embedding column to centroidsJson + // SQLite doesn't support RENAME COLUMN directly, so we need to: + // 1. Create new table with new schema + // 2. Copy data (converting single embedding to centroid JSON) + // 3. Drop old table + // 4. Rename new table + + // Create new table + database.execSQL(""" + CREATE TABLE IF NOT EXISTS face_models_new ( + id TEXT PRIMARY KEY NOT NULL, + personId TEXT NOT NULL, + centroidsJson TEXT NOT NULL, + trainingImageCount INTEGER NOT NULL, + averageConfidence REAL NOT NULL, + createdAt INTEGER NOT NULL, + updatedAt INTEGER NOT NULL, + lastUsed INTEGER, + isActive INTEGER NOT NULL, + FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE + ) + """) + + // Copy data, converting embedding to centroidsJson format + // This converts single embedding to a list with one centroid + database.execSQL(""" + INSERT INTO face_models_new + SELECT + id, + personId, + '[{"embedding":' || REPLACE(REPLACE(embedding, ',', ','), ',', ',') || ',"effectiveTimestamp":' || createdAt || ',"ageAtCapture":null,"photoCount":' || trainingImageCount || ',"timeRangeMonths":12,"avgConfidence":' || averageConfidence || '}]' as centroidsJson, + trainingImageCount, + averageConfidence, + createdAt, + updatedAt, + lastUsed, + isActive + FROM face_models + """) + + // Drop old table + database.execSQL("DROP TABLE face_models") + + // Rename new table + database.execSQL("ALTER TABLE face_models_new RENAME TO face_models") + + // Recreate index + database.execSQL("CREATE UNIQUE INDEX IF NOT EXISTS index_face_models_personId ON face_models(personId)") + + // ===== STEP 3: Create person_age_tags table ===== + database.execSQL(""" + CREATE TABLE IF NOT EXISTS person_age_tags ( + id TEXT PRIMARY KEY NOT NULL, + personId TEXT NOT NULL, + imageId TEXT NOT NULL, + ageAtCapture INTEGER NOT NULL, + tagValue TEXT NOT NULL, + confidence REAL NOT NULL, + createdAt INTEGER NOT NULL, + FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE, + FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE + ) + """) + + // Create indices for fast lookups + database.execSQL("CREATE INDEX IF NOT EXISTS index_person_age_tags_personId ON person_age_tags(personId)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_person_age_tags_imageId ON person_age_tags(imageId)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_person_age_tags_ageAtCapture ON person_age_tags(ageAtCapture)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_person_age_tags_tagValue ON person_age_tags(tagValue)") + } +} + +/** + * MIGRATION 8 → 9 (Enhanced Face Cache) * - * val MIGRATION_6_7 = object : Migration(6, 7) { - * override fun migrate(database: SupportSQLiteDatabase) { - * // Add face detection cache columns - * database.execSQL("ALTER TABLE images ADD COLUMN hasFaces INTEGER DEFAULT NULL") - * database.execSQL("ALTER TABLE images ADD COLUMN faceCount INTEGER DEFAULT NULL") - * database.execSQL("ALTER TABLE images ADD COLUMN facesLastDetected INTEGER DEFAULT NULL") - * database.execSQL("ALTER TABLE images ADD COLUMN faceDetectionVersion INTEGER DEFAULT NULL") + * Changes: + * 1. Create face_cache table for per-face metadata + */ +val MIGRATION_8_9 = object : Migration(8, 9) { + override fun migrate(database: SupportSQLiteDatabase) { + + // Create face_cache table + database.execSQL(""" + CREATE TABLE IF NOT EXISTS face_cache ( + imageId TEXT NOT NULL, + faceIndex INTEGER NOT NULL, + boundingBox TEXT NOT NULL, + faceWidth INTEGER NOT NULL, + faceHeight INTEGER NOT NULL, + faceAreaRatio REAL NOT NULL, + qualityScore REAL NOT NULL, + isLargeEnough INTEGER NOT NULL, + isFrontal INTEGER NOT NULL, + hasGoodLighting INTEGER NOT NULL, + embedding TEXT, + confidence REAL NOT NULL, + imageWidth INTEGER NOT NULL DEFAULT 0, + imageHeight INTEGER NOT NULL DEFAULT 0, + cacheVersion INTEGER NOT NULL DEFAULT 1, + cachedAt INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY(imageId, faceIndex), + FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE + ) + """) + + // Create indices for fast queries + database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_imageId ON face_cache(imageId)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_qualityScore ON face_cache(qualityScore)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_isLargeEnough ON face_cache(isLargeEnough)") + } +} + +/** + * MIGRATION 9 → 10 (User Feedback Loop) * - * // Create indices - * database.execSQL("CREATE INDEX IF NOT EXISTS index_images_hasFaces ON images(hasFaces)") - * database.execSQL("CREATE INDEX IF NOT EXISTS index_images_faceCount ON images(faceCount)") - * } - * } + * Changes: + * 1. Create user_feedback table for storing user corrections + */ +val MIGRATION_9_10 = object : Migration(9, 10) { + override fun migrate(database: SupportSQLiteDatabase) { + + // Create user_feedback table + database.execSQL(""" + CREATE TABLE IF NOT EXISTS user_feedback ( + id TEXT PRIMARY KEY NOT NULL, + imageId TEXT NOT NULL, + faceIndex INTEGER NOT NULL, + clusterId INTEGER, + personId TEXT, + feedbackType TEXT NOT NULL, + originalConfidence REAL NOT NULL, + userNote TEXT, + timestamp INTEGER NOT NULL, + FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE, + FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE + ) + """) + + // Create indices for fast lookups + database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_imageId ON user_feedback(imageId)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_clusterId ON user_feedback(clusterId)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_personId ON user_feedback(personId)") + database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_feedbackType ON user_feedback(feedbackType)") + } +} + +/** + * PRODUCTION MIGRATION NOTES: * - * Then in your database builder: - * Room.databaseBuilder(context, AppDatabase::class.java, "database_name") - * .addMigrations(MIGRATION_6_7) // Add this + * Before shipping to users, update DatabaseModule to use migrations: + * + * Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db") + * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations * // .fallbackToDestructiveMigration() // Remove this * .build() */ \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Collectiondao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Collectiondao.kt index 8541231..c496ecb 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Collectiondao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Collectiondao.kt @@ -6,39 +6,71 @@ import com.placeholder.sherpai2.data.local.model.CollectionWithDetails import kotlinx.coroutines.flow.Flow /** - * CollectionDao - Manage user collections + * CollectionDao - Data Access Object for managing user-defined and system-generated collections. + * * Provides an interface for CRUD operations on the 'collections' table and manages the + * many-to-many relationships between collections and images using junction tables. */ @Dao interface CollectionDao { - // ========================================== - // BASIC OPERATIONS - // ========================================== + // ========================================================================================= + // BASIC CRUD OPERATIONS + // ========================================================================================= + /** + * Persists a new collection entity. + * @param collection The entity to be inserted. + * @return The row ID of the newly inserted item. + * Strategy: REPLACE ensures that if a collection with the same ID exists, it is overwritten. + */ @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insert(collection: CollectionEntity): Long + /** + * Updates an existing collection based on its primary key. + * @param collection The entity containing updated fields. + */ @Update suspend fun update(collection: CollectionEntity) + /** + * Removes a specific collection entity from the database. + * @param collection The entity object to be deleted. + */ @Delete suspend fun delete(collection: CollectionEntity) + /** + * Deletes a collection entry directly by its unique string identifier. + * @param collectionId The unique ID of the collection to remove. + */ @Query("DELETE FROM collections WHERE collectionId = :collectionId") suspend fun deleteById(collectionId: String) + /** + * One-shot fetch for a specific collection. + * @param collectionId The unique ID of the collection. + * @return The CollectionEntity if found, null otherwise. + */ @Query("SELECT * FROM collections WHERE collectionId = :collectionId") suspend fun getById(collectionId: String): CollectionEntity? + /** + * Reactive stream for a specific collection. + * @param collectionId The unique ID of the collection. + * @return A Flow that emits the CollectionEntity whenever that specific row changes. + */ @Query("SELECT * FROM collections WHERE collectionId = :collectionId") fun getByIdFlow(collectionId: String): Flow - // ========================================== - // LIST QUERIES - // ========================================== + // ========================================================================================= + // LIST QUERIES (Observables) + // ========================================================================================= /** - * Get all collections ordered by pinned, then by creation date + * Retrieves all collections for the main UI list. + * Ordering: Prioritizes 'pinned' items first, then sorts by newest creation date. + * @return A Flow emitting a list of collections, updating automatically on table changes. */ @Query(""" SELECT * FROM collections @@ -46,6 +78,11 @@ interface CollectionDao { """) fun getAllCollections(): Flow> + /** + * Retrieves collections filtered by their type (e.g., SMART, STATIC, FAVORITE). + * @param type The category string to filter by. + * @return A Flow emitting the filtered list. + */ @Query(""" SELECT * FROM collections WHERE type = :type @@ -53,15 +90,22 @@ interface CollectionDao { """) fun getCollectionsByType(type: String): Flow> + /** + * Retrieves the single system-defined Favorite collection. + * Used for quick access to the standard 'Likes' functionality. + */ @Query("SELECT * FROM collections WHERE type = 'FAVORITE' LIMIT 1") suspend fun getFavoriteCollection(): CollectionEntity? - // ========================================== - // COLLECTION WITH DETAILS - // ========================================== + // ========================================================================================= + // COMPLEX RELATIONSHIPS & DATA MODELS + // ========================================================================================= /** - * Get collection with actual photo count + * Retrieves a specialized model [CollectionWithDetails] which includes the base collection + * data plus a dynamically calculated photo count from the junction table. + * * @Transaction is required here because the query involves a subquery/multiple operations + * to ensure data consistency across the result set. */ @Transaction @Query(""" @@ -75,25 +119,42 @@ interface CollectionDao { """) fun getCollectionWithDetails(collectionId: String): Flow - // ========================================== - // IMAGE MANAGEMENT - // ========================================== + // ========================================================================================= + // IMAGE MANAGEMENT (Junction Table: collection_images) + // ========================================================================================= + /** + * Maps an image to a collection in the junction table. + */ @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun addImage(collectionImage: CollectionImageEntity) + /** + * Batch maps multiple images to a collection. Useful for bulk imports or multi-selection. + */ @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun addImages(collectionImages: List) + /** + * Removes a specific image from a specific collection. + * Note: This does not delete the image from the 'images' table, only the relationship. + */ @Query(""" DELETE FROM collection_images WHERE collectionId = :collectionId AND imageId = :imageId """) suspend fun removeImage(collectionId: String, imageId: String) + /** + * Clears all image associations for a specific collection. + */ @Query("DELETE FROM collection_images WHERE collectionId = :collectionId") suspend fun clearAllImages(collectionId: String) + /** + * Performs a JOIN to retrieve actual ImageEntity objects associated with a collection. + * Ordered by the user's custom sort order, then by the time the image was added. + */ @Query(""" SELECT i.* FROM images i JOIN collection_images ci ON i.imageId = ci.imageId @@ -102,6 +163,9 @@ interface CollectionDao { """) fun getImagesInCollection(collectionId: String): Flow> + /** + * Fetches the top 4 images for a collection to be used as UI thumbnails/previews. + */ @Query(""" SELECT i.* FROM images i JOIN collection_images ci ON i.imageId = ci.imageId @@ -111,12 +175,19 @@ interface CollectionDao { """) suspend fun getPreviewImages(collectionId: String): List + /** + * Returns the current number of images associated with a collection. + */ @Query(""" SELECT COUNT(*) FROM collection_images WHERE collectionId = :collectionId """) suspend fun getPhotoCount(collectionId: String): Int + /** + * Checks if a specific image is already present in a collection. + * Returns true if a record exists. + */ @Query(""" SELECT EXISTS( SELECT 1 FROM collection_images @@ -125,19 +196,31 @@ interface CollectionDao { """) suspend fun containsImage(collectionId: String, imageId: String): Boolean - // ========================================== - // FILTER MANAGEMENT (for SMART collections) - // ========================================== + // ========================================================================================= + // FILTER MANAGEMENT (For Smart/Dynamic Collections) + // ========================================================================================= + /** + * Inserts a filter criteria for a Smart Collection. + */ @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insertFilter(filter: CollectionFilterEntity) + /** + * Batch inserts multiple filter criteria. + */ @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insertFilters(filters: List) + /** + * Removes all dynamic filter rules for a collection. + */ @Query("DELETE FROM collection_filters WHERE collectionId = :collectionId") suspend fun clearFilters(collectionId: String) + /** + * Retrieves the list of rules used to populate a Smart Collection. + */ @Query(""" SELECT * FROM collection_filters WHERE collectionId = :collectionId @@ -145,6 +228,9 @@ interface CollectionDao { """) suspend fun getFilters(collectionId: String): List + /** + * Observable stream of filters for a Smart Collection. + */ @Query(""" SELECT * FROM collection_filters WHERE collectionId = :collectionId @@ -152,30 +238,39 @@ interface CollectionDao { """) fun getFiltersFlow(collectionId: String): Flow> - // ========================================== - // STATISTICS - // ========================================== + // ========================================================================================= + // AGGREGATE STATISTICS + // ========================================================================================= + /** Total number of collections defined. */ @Query("SELECT COUNT(*) FROM collections") suspend fun getCollectionCount(): Int + /** Count of collections that update dynamically based on filters. */ @Query("SELECT COUNT(*) FROM collections WHERE type = 'SMART'") suspend fun getSmartCollectionCount(): Int + /** Count of manually curated collections. */ @Query("SELECT COUNT(*) FROM collections WHERE type = 'STATIC'") suspend fun getStaticCollectionCount(): Int + /** + * Returns the sum of the photoCount cache across all collections. + * Returns nullable Int in case the table is empty. + */ @Query(""" SELECT SUM(photoCount) FROM collections """) suspend fun getTotalPhotosInCollections(): Int? - // ========================================== - // UPDATES - // ========================================== + // ========================================================================================= + // GRANULAR UPDATES (Optimization) + // ========================================================================================= /** - * Update photo count cache (call after adding/removing images) + * Synchronizes the 'photoCount' denormalized field in the collections table with + * the actual count in the junction table. Should be called after image additions/removals. + * * @param updatedAt Timestamp of the operation. */ @Query(""" UPDATE collections @@ -188,6 +283,9 @@ interface CollectionDao { """) suspend fun updatePhotoCount(collectionId: String, updatedAt: Long) + /** + * Updates the thumbnail/cover image for the collection card. + */ @Query(""" UPDATE collections SET coverImageUri = :imageUri, updatedAt = :updatedAt @@ -195,6 +293,9 @@ interface CollectionDao { """) suspend fun updateCoverImage(collectionId: String, imageUri: String?, updatedAt: Long) + /** + * Toggles the pinned status of a collection. + */ @Query(""" UPDATE collections SET isPinned = :isPinned, updatedAt = :updatedAt @@ -202,6 +303,9 @@ interface CollectionDao { """) suspend fun updatePinned(collectionId: String, isPinned: Boolean, updatedAt: Long) + /** + * Updates the name and description of a collection. + */ @Query(""" UPDATE collections SET name = :name, description = :description, updatedAt = :updatedAt diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt new file mode 100644 index 0000000..6da2e0e --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt @@ -0,0 +1,134 @@ +package com.placeholder.sherpai2.data.local.dao + +import androidx.room.* +import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity + +/** + * FaceCacheDao - NO SOLO-PHOTO FILTER + * + * CRITICAL CHANGE: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Removed all faceCount filters from queries + * + * WHY: + * - Group photos contain high-quality faces (especially for children) + * - IoU matching ensures we extract the CORRECT face from group photos + * - Rejecting group photos was eliminating 60-70% of quality faces! + * + * RESULT: + * - 2-3x more faces for clustering + * - Quality remains high (still filter by size + score) + * - Better clusters, especially for children + */ +@Dao +interface FaceCacheDao { + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insert(faceCacheEntity: FaceCacheEntity) + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertAll(faceCacheEntities: List) + + @Update + suspend fun update(faceCacheEntity: FaceCacheEntity) + + /** + * Get ALL quality faces - INCLUDES GROUP PHOTOS! + * + * Quality filters (still strict): + * - faceAreaRatio >= minRatio (default 3% of image) + * - qualityScore >= minQuality (default 0.6 = 60%) + * - Has embedding + * + * NO faceCount filter! + */ + @Query(""" + SELECT fc.* + FROM face_cache fc + WHERE fc.faceAreaRatio >= :minRatio + AND fc.qualityScore >= :minQuality + AND fc.embedding IS NOT NULL + ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC + LIMIT :limit + """) + suspend fun getAllQualityFaces( + minRatio: Float = 0.03f, + minQuality: Float = 0.6f, + limit: Int = Int.MAX_VALUE + ): List + + /** + * Get quality faces WITHOUT embeddings - FOR PATH 2 + * + * These have good metadata but need embeddings generated. + * INCLUDES GROUP PHOTOS - IoU matching will handle extraction! + */ + @Query(""" + SELECT fc.* + FROM face_cache fc + WHERE fc.faceAreaRatio >= :minRatio + AND fc.qualityScore >= :minQuality + AND fc.embedding IS NULL + ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC + LIMIT :limit + """) + suspend fun getQualityFacesWithoutEmbeddings( + minRatio: Float = 0.03f, + minQuality: Float = 0.6f, + limit: Int = 5000 + ): List + + /** + * Count faces WITH embeddings (Path 1 check) + */ + @Query(""" + SELECT COUNT(*) + FROM face_cache + WHERE embedding IS NOT NULL + AND qualityScore >= :minQuality + """) + suspend fun countFacesWithEmbeddings(minQuality: Float = 0.6f): Int + + /** + * Count faces WITHOUT embeddings (Path 2 check) + */ + @Query(""" + SELECT COUNT(*) + FROM face_cache + WHERE embedding IS NULL + AND qualityScore >= :minQuality + """) + suspend fun countFacesWithoutEmbeddings(minQuality: Float = 0.6f): Int + + /** + * Get faces for specific image (for IoU matching) + */ + @Query("SELECT * FROM face_cache WHERE imageId = :imageId") + suspend fun getFaceCacheForImage(imageId: String): List + + /** + * Cache statistics + */ + @Query(""" + SELECT + COUNT(*) as totalFaces, + COUNT(CASE WHEN embedding IS NOT NULL THEN 1 END) as withEmbeddings, + AVG(qualityScore) as avgQuality, + AVG(faceAreaRatio) as avgSize + FROM face_cache + """) + suspend fun getCacheStats(): CacheStats + + @Query("DELETE FROM face_cache WHERE imageId = :imageId") + suspend fun deleteCacheForImage(imageId: String) + + @Query("DELETE FROM face_cache") + suspend fun deleteAll() +} + +data class CacheStats( + val totalFaces: Int, + val withEmbeddings: Int, + val avgQuality: Float, + val avgSize: Float +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt index 1bafe58..13e7595 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt @@ -297,6 +297,23 @@ interface ImageDao { """) suspend fun invalidateFaceDetectionCache(newVersion: Int) + /** + * Clear ALL face detection cache (force full rebuild). + * Sets all face detection fields to NULL for all images. + * + * Use this for "Force Rebuild Cache" button. + * This is different from invalidateFaceDetectionCache which only + * invalidates old versions - this clears EVERYTHING. + */ + @Query(""" + UPDATE images + SET hasFaces = NULL, + faceCount = NULL, + facesLastDetected = NULL, + faceDetectionVersion = NULL + """) + suspend fun clearAllFaceDetectionCache() + // ========================================== // STATISTICS QUERIES // ========================================== diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt index 994aada..34e13d9 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt @@ -48,4 +48,4 @@ interface PersonDao { @Query("SELECT EXISTS(SELECT 1 FROM persons WHERE id = :personId)") suspend fun personExists(personId: String): Boolean -} +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Personagetagdao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Personagetagdao.kt new file mode 100644 index 0000000..d5ae390 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Personagetagdao.kt @@ -0,0 +1,104 @@ +package com.placeholder.sherpai2.data.local.dao + +import androidx.room.* +import com.placeholder.sherpai2.data.local.entity.PersonAgeTagEntity +import kotlinx.coroutines.flow.Flow + +/** + * PersonAgeTagDao - Manage searchable age tags for children + * + * USAGE EXAMPLES: + * - Search "emma age 3" → getImageIdsForTag("emma_age3") + * - Find all photos of Emma at age 5 → getImageIdsForPersonAtAge(emmaId, 5) + * - Get age progression → getTagsForPerson(emmaId) sorted by age + */ +@Dao +interface PersonAgeTagDao { + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertTag(tag: PersonAgeTagEntity) + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertTags(tags: List) + + /** + * Get all age tags for a person (sorted by age) + * Useful for age progression timeline + */ + @Query("SELECT * FROM person_age_tags WHERE personId = :personId ORDER BY ageAtCapture ASC") + suspend fun getTagsForPerson(personId: String): List + + /** + * Get all age tags for an image + */ + @Query("SELECT * FROM person_age_tags WHERE imageId = :imageId") + suspend fun getTagsForImage(imageId: String): List + + /** + * Search by tag value (e.g., "emma_age3") + * Returns all image IDs matching this tag + */ + @Query("SELECT DISTINCT imageId FROM person_age_tags WHERE tagValue = :tagValue") + suspend fun getImageIdsForTag(tagValue: String): List + + /** + * Get images of a person at a specific age + */ + @Query("SELECT DISTINCT imageId FROM person_age_tags WHERE personId = :personId AND ageAtCapture = :age") + suspend fun getImageIdsForPersonAtAge(personId: String, age: Int): List + + /** + * Get images of a person in an age range + */ + @Query(""" + SELECT DISTINCT imageId FROM person_age_tags + WHERE personId = :personId + AND ageAtCapture BETWEEN :minAge AND :maxAge + ORDER BY ageAtCapture ASC + """) + suspend fun getImageIdsForPersonAgeRange(personId: String, minAge: Int, maxAge: Int): List + + /** + * Get all unique ages for a person (for age picker UI) + */ + @Query("SELECT DISTINCT ageAtCapture FROM person_age_tags WHERE personId = :personId ORDER BY ageAtCapture ASC") + suspend fun getAgesForPerson(personId: String): List + + /** + * Delete all tags for a person + */ + @Query("DELETE FROM person_age_tags WHERE personId = :personId") + suspend fun deleteTagsForPerson(personId: String) + + /** + * Delete all tags for an image + */ + @Query("DELETE FROM person_age_tags WHERE imageId = :imageId") + suspend fun deleteTagsForImage(imageId: String) + + /** + * Get count of photos at each age (for statistics) + */ + @Query(""" + SELECT ageAtCapture, COUNT(DISTINCT imageId) as count + FROM person_age_tags + WHERE personId = :personId + GROUP BY ageAtCapture + ORDER BY ageAtCapture ASC + """) + suspend fun getPhotoCountByAge(personId: String): List + + /** + * Flow version for reactive UI + */ + @Query("SELECT * FROM person_age_tags WHERE personId = :personId ORDER BY ageAtCapture ASC") + fun getTagsForPersonFlow(personId: String): Flow> +} + +/** + * Data class for age photo count statistics + */ +data class AgePhotoCount( + val ageAtCapture: Int, + val count: Int +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Userfeedbackdao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Userfeedbackdao.kt new file mode 100644 index 0000000..680e68b --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Userfeedbackdao.kt @@ -0,0 +1,212 @@ +package com.placeholder.sherpai2.data.local.dao + +import androidx.room.* +import com.placeholder.sherpai2.data.local.entity.FeedbackType +import com.placeholder.sherpai2.data.local.entity.UserFeedbackEntity +import kotlinx.coroutines.flow.Flow + +/** + * UserFeedbackDao - Query user corrections and feedback + * + * KEY QUERIES: + * - Get feedback for cluster validation + * - Find rejected faces to exclude from training + * - Track feedback statistics for quality metrics + * - Support cluster refinement workflow + */ +@Dao +interface UserFeedbackDao { + + // ═══════════════════════════════════════ + // INSERT / UPDATE + // ═══════════════════════════════════════ + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insert(feedback: UserFeedbackEntity): Long + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertAll(feedbacks: List) + + @Update + suspend fun update(feedback: UserFeedbackEntity) + + @Delete + suspend fun delete(feedback: UserFeedbackEntity) + + // ═══════════════════════════════════════ + // CLUSTER VALIDATION QUERIES + // ═══════════════════════════════════════ + + /** + * Get all feedback for a cluster + * Used during validation to see what user has reviewed + */ + @Query("SELECT * FROM user_feedback WHERE clusterId = :clusterId ORDER BY timestamp DESC") + suspend fun getFeedbackForCluster(clusterId: Int): List + + /** + * Get rejected faces for a cluster + * These faces should be EXCLUDED from training + */ + @Query(""" + SELECT * FROM user_feedback + WHERE clusterId = :clusterId + AND feedbackType = 'REJECTED_MATCH' + """) + suspend fun getRejectedFacesForCluster(clusterId: Int): List + + /** + * Get confirmed faces for a cluster + * These faces are SAFE for training + */ + @Query(""" + SELECT * FROM user_feedback + WHERE clusterId = :clusterId + AND feedbackType = 'CONFIRMED_MATCH' + """) + suspend fun getConfirmedFacesForCluster(clusterId: Int): List + + /** + * Count feedback by type for a cluster + * Used to show stats: "15 confirmed, 3 rejected" + */ + @Query(""" + SELECT feedbackType, COUNT(*) as count + FROM user_feedback + WHERE clusterId = :clusterId + GROUP BY feedbackType + """) + suspend fun getFeedbackStatsByCluster(clusterId: Int): List + + // ═══════════════════════════════════════ + // PERSON FEEDBACK QUERIES + // ═══════════════════════════════════════ + + /** + * Get all feedback for a person + * Used to show history of corrections + */ + @Query("SELECT * FROM user_feedback WHERE personId = :personId ORDER BY timestamp DESC") + suspend fun getFeedbackForPerson(personId: String): List + + /** + * Get rejected faces for a person + * User said "this is NOT X" - exclude from model improvement + */ + @Query(""" + SELECT * FROM user_feedback + WHERE personId = :personId + AND feedbackType = 'REJECTED_MATCH' + """) + suspend fun getRejectedFacesForPerson(personId: String): List + + /** + * Flow version for reactive UI + */ + @Query("SELECT * FROM user_feedback WHERE personId = :personId ORDER BY timestamp DESC") + fun observeFeedbackForPerson(personId: String): Flow> + + // ═══════════════════════════════════════ + // IMAGE QUERIES + // ═══════════════════════════════════════ + + /** + * Get feedback for a specific image + */ + @Query("SELECT * FROM user_feedback WHERE imageId = :imageId") + suspend fun getFeedbackForImage(imageId: String): List + + /** + * Check if user has provided feedback for a specific face + */ + @Query(""" + SELECT EXISTS( + SELECT 1 FROM user_feedback + WHERE imageId = :imageId + AND faceIndex = :faceIndex + ) + """) + suspend fun hasFeedbackForFace(imageId: String, faceIndex: Int): Boolean + + // ═══════════════════════════════════════ + // STATISTICS & ANALYTICS + // ═══════════════════════════════════════ + + /** + * Get total feedback count + */ + @Query("SELECT COUNT(*) FROM user_feedback") + suspend fun getTotalFeedbackCount(): Int + + /** + * Get feedback count by type (global) + */ + @Query(""" + SELECT feedbackType, COUNT(*) as count + FROM user_feedback + GROUP BY feedbackType + """) + suspend fun getGlobalFeedbackStats(): List + + /** + * Get average original confidence for rejected faces + * Helps identify if low confidence → more rejections + */ + @Query(""" + SELECT AVG(originalConfidence) + FROM user_feedback + WHERE feedbackType = 'REJECTED_MATCH' + """) + suspend fun getAverageConfidenceForRejectedFaces(): Float? + + /** + * Find faces with low confidence that were confirmed + * These are "surprising successes" - model worked despite low confidence + */ + @Query(""" + SELECT * FROM user_feedback + WHERE feedbackType = 'CONFIRMED_MATCH' + AND originalConfidence < :threshold + ORDER BY originalConfidence ASC + """) + suspend fun getLowConfidenceSuccesses(threshold: Float = 0.7f): List + + // ═══════════════════════════════════════ + // CLEANUP + // ═══════════════════════════════════════ + + /** + * Delete all feedback for a cluster + * Called when cluster is deleted or refined + */ + @Query("DELETE FROM user_feedback WHERE clusterId = :clusterId") + suspend fun deleteFeedbackForCluster(clusterId: Int) + + /** + * Delete all feedback for a person + * Called when person is deleted + */ + @Query("DELETE FROM user_feedback WHERE personId = :personId") + suspend fun deleteFeedbackForPerson(personId: String) + + /** + * Delete old feedback (older than X days) + * Keep database size manageable + */ + @Query("DELETE FROM user_feedback WHERE timestamp < :cutoffTimestamp") + suspend fun deleteOldFeedback(cutoffTimestamp: Long) + + /** + * Clear all feedback (nuclear option) + */ + @Query("DELETE FROM user_feedback") + suspend fun deleteAll() +} + +/** + * Result class for feedback statistics + */ +data class FeedbackStat( + val feedbackType: String, + val count: Int +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt new file mode 100644 index 0000000..eca4f8c --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt @@ -0,0 +1,156 @@ +package com.placeholder.sherpai2.data.local.entity + +import androidx.room.ColumnInfo +import androidx.room.Entity +import androidx.room.ForeignKey +import androidx.room.Index +import androidx.room.PrimaryKey +import java.util.UUID + +/** + * FaceCacheEntity - Per-face metadata for intelligent filtering + * + * PURPOSE: Store face quality metrics during initial cache population + * BENEFIT: Pre-filter to high-quality faces BEFORE clustering + * + * ENABLES QUERIES LIKE: + * - "Give me all solo photos with large, clear faces" + * - "Filter to faces that are > 15% of image" + * - "Exclude blurry/distant/profile faces" + * + * POPULATED BY: PopulateFaceDetectionCacheUseCase (enhanced version) + * USED BY: FaceClusteringService for smart photo selection + */ +@Entity( + tableName = "face_cache", + foreignKeys = [ + ForeignKey( + entity = ImageEntity::class, + parentColumns = ["imageId"], + childColumns = ["imageId"], + onDelete = ForeignKey.CASCADE + ) + ], + indices = [ + Index(value = ["imageId"]), + Index(value = ["faceIndex"]), + Index(value = ["faceAreaRatio"]), + Index(value = ["qualityScore"]), + Index(value = ["imageId", "faceIndex"], unique = true) + ] +) +data class FaceCacheEntity( + @PrimaryKey + @ColumnInfo(name = "id") + val id: String = UUID.randomUUID().toString(), + + @ColumnInfo(name = "imageId") + val imageId: String, + + @ColumnInfo(name = "faceIndex") + val faceIndex: Int, // 0-based index for multiple faces in image + + // FACE METRICS (for filtering) + @ColumnInfo(name = "boundingBox") + val boundingBox: String, // "left,top,right,bottom" + + @ColumnInfo(name = "faceWidth") + val faceWidth: Int, // pixels + + @ColumnInfo(name = "faceHeight") + val faceHeight: Int, // pixels + + @ColumnInfo(name = "faceAreaRatio") + val faceAreaRatio: Float, // face area / image area (0.0 - 1.0) + + @ColumnInfo(name = "imageWidth") + val imageWidth: Int, // Full image width + + @ColumnInfo(name = "imageHeight") + val imageHeight: Int, // Full image height + + // QUALITY INDICATORS + @ColumnInfo(name = "qualityScore") + val qualityScore: Float, // 0.0-1.0 (combines size + clarity + angle) + + @ColumnInfo(name = "isLargeEnough") + val isLargeEnough: Boolean, // faceAreaRatio >= 0.15 AND min 200x200px + + @ColumnInfo(name = "isFrontal") + val isFrontal: Boolean, // Face angle roughly frontal (from ML Kit) + + @ColumnInfo(name = "hasGoodLighting") + val hasGoodLighting: Boolean, // Not too dark/bright (heuristic) + + // EMBEDDING (optional - for super fast clustering) + @ColumnInfo(name = "embedding") + val embedding: String?, // Pre-computed 192D embedding (comma-separated) + + // METADATA + @ColumnInfo(name = "confidence") + val confidence: Float, // ML Kit detection confidence + + @ColumnInfo(name = "detectedAt") + val detectedAt: Long = System.currentTimeMillis(), + + @ColumnInfo(name = "cacheVersion") + val cacheVersion: Int = CURRENT_CACHE_VERSION +) { + companion object { + const val CURRENT_CACHE_VERSION = 1 + + /** + * Create from ML Kit face detection result + */ + fun create( + imageId: String, + faceIndex: Int, + boundingBox: android.graphics.Rect, + imageWidth: Int, + imageHeight: Int, + confidence: Float, + isFrontal: Boolean, + embedding: FloatArray? = null + ): FaceCacheEntity { + val faceWidth = boundingBox.width() + val faceHeight = boundingBox.height() + val faceArea = faceWidth * faceHeight + val imageArea = imageWidth * imageHeight + val faceAreaRatio = faceArea.toFloat() / imageArea.toFloat() + + // Calculate quality score + val sizeScore = (faceAreaRatio * 5).coerceIn(0f, 1f) // 20% = perfect + val pixelScore = if (faceWidth >= 200 && faceHeight >= 200) 1f else 0.5f + val angleScore = if (isFrontal) 1f else 0.7f + val qualityScore = (sizeScore + pixelScore + angleScore) / 3f + + val isLargeEnough = faceAreaRatio >= 0.15f && faceWidth >= 200 && faceHeight >= 200 + + return FaceCacheEntity( + imageId = imageId, + faceIndex = faceIndex, + boundingBox = "${boundingBox.left},${boundingBox.top},${boundingBox.right},${boundingBox.bottom}", + faceWidth = faceWidth, + faceHeight = faceHeight, + faceAreaRatio = faceAreaRatio, + imageWidth = imageWidth, + imageHeight = imageHeight, + qualityScore = qualityScore, + isLargeEnough = isLargeEnough, + isFrontal = isFrontal, + hasGoodLighting = true, // TODO: Implement lighting analysis + embedding = embedding?.joinToString(","), + confidence = confidence + ) + } + } + + fun getBoundingBox(): android.graphics.Rect { + val parts = boundingBox.split(",").map { it.toInt() } + return android.graphics.Rect(parts[0], parts[1], parts[2], parts[3]) + } + + fun getEmbedding(): FloatArray? { + return embedding?.split(",")?.map { it.toFloat() }?.toFloatArray() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt index 0bddef8..64a8e8e 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt @@ -5,19 +5,24 @@ import androidx.room.Entity import androidx.room.ForeignKey import androidx.room.Index import androidx.room.PrimaryKey +import org.json.JSONArray +import org.json.JSONObject import java.util.UUID /** - * PersonEntity - NO DEFAULT VALUES for KSP compatibility + * PersonEntity - ENHANCED with child tracking and sibling relationships */ @Entity( tableName = "persons", - indices = [Index(value = ["name"])] + indices = [ + Index(value = ["name"]), + Index(value = ["familyGroupId"]) + ] ) data class PersonEntity( @PrimaryKey @ColumnInfo(name = "id") - val id: String, // ← No default + val id: String, @ColumnInfo(name = "name") val name: String, @@ -25,26 +30,48 @@ data class PersonEntity( @ColumnInfo(name = "dateOfBirth") val dateOfBirth: Long?, + @ColumnInfo(name = "isChild") + val isChild: Boolean, // NEW: Auto-set based on age + + @ColumnInfo(name = "siblingIds") + val siblingIds: String?, // NEW: JSON list ["uuid1", "uuid2"] + + @ColumnInfo(name = "familyGroupId") + val familyGroupId: String?, // NEW: UUID for family unit + @ColumnInfo(name = "relationship") val relationship: String?, @ColumnInfo(name = "createdAt") - val createdAt: Long, // ← No default + val createdAt: Long, @ColumnInfo(name = "updatedAt") - val updatedAt: Long // ← No default + val updatedAt: Long ) { companion object { fun create( name: String, dateOfBirth: Long? = null, + isChild: Boolean = false, + siblingIds: List = emptyList(), relationship: String? = null ): PersonEntity { val now = System.currentTimeMillis() + + // Create family group if siblings exist + val familyGroupId = if (siblingIds.isNotEmpty()) { + UUID.randomUUID().toString() + } else null + return PersonEntity( id = UUID.randomUUID().toString(), name = name, dateOfBirth = dateOfBirth, + isChild = isChild, + siblingIds = if (siblingIds.isNotEmpty()) { + JSONArray(siblingIds).toString() + } else null, + familyGroupId = familyGroupId, relationship = relationship, createdAt = now, updatedAt = now @@ -52,6 +79,17 @@ data class PersonEntity( } } + fun getSiblingIds(): List { + return if (siblingIds != null) { + try { + val jsonArray = JSONArray(siblingIds) + (0 until jsonArray.length()).map { jsonArray.getString(it) } + } catch (e: Exception) { + emptyList() + } + } else emptyList() + } + fun getAge(): Int? { if (dateOfBirth == null) return null val now = System.currentTimeMillis() @@ -74,7 +112,7 @@ data class PersonEntity( } /** - * FaceModelEntity - NO DEFAULT VALUES + * FaceModelEntity - MULTI-CENTROID support for temporal tracking */ @Entity( tableName = "face_models", @@ -91,13 +129,13 @@ data class PersonEntity( data class FaceModelEntity( @PrimaryKey @ColumnInfo(name = "id") - val id: String, // ← No default + val id: String, @ColumnInfo(name = "personId") val personId: String, - @ColumnInfo(name = "embedding") - val embedding: String, + @ColumnInfo(name = "centroidsJson") + val centroidsJson: String, // NEW: List as JSON @ColumnInfo(name = "trainingImageCount") val trainingImageCount: Int, @@ -106,10 +144,10 @@ data class FaceModelEntity( val averageConfidence: Float, @ColumnInfo(name = "createdAt") - val createdAt: Long, // ← No default + val createdAt: Long, @ColumnInfo(name = "updatedAt") - val updatedAt: Long, // ← No default + val updatedAt: Long, @ColumnInfo(name = "lastUsed") val lastUsed: Long?, @@ -118,17 +156,42 @@ data class FaceModelEntity( val isActive: Boolean ) { companion object { + /** + * Backwards compatible create() method + * Used by existing FaceRecognitionRepository code + */ fun create( personId: String, embeddingArray: FloatArray, trainingImageCount: Int, averageConfidence: Float + ): FaceModelEntity { + return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence) + } + + /** + * Create from single embedding (backwards compatible) + */ + fun createFromEmbedding( + personId: String, + embeddingArray: FloatArray, + trainingImageCount: Int, + averageConfidence: Float ): FaceModelEntity { val now = System.currentTimeMillis() + val centroid = TemporalCentroid( + embedding = embeddingArray.toList(), + effectiveTimestamp = now, + ageAtCapture = null, + photoCount = trainingImageCount, + timeRangeMonths = 12, + avgConfidence = averageConfidence + ) + return FaceModelEntity( id = UUID.randomUUID().toString(), personId = personId, - embedding = embeddingArray.joinToString(","), + centroidsJson = serializeCentroids(listOf(centroid)), trainingImageCount = trainingImageCount, averageConfidence = averageConfidence, createdAt = now, @@ -137,15 +200,106 @@ data class FaceModelEntity( isActive = true ) } + + /** + * Create from multiple centroids (temporal tracking) + */ + fun createFromCentroids( + personId: String, + centroids: List, + trainingImageCount: Int, + averageConfidence: Float + ): FaceModelEntity { + val now = System.currentTimeMillis() + return FaceModelEntity( + id = UUID.randomUUID().toString(), + personId = personId, + centroidsJson = serializeCentroids(centroids), + trainingImageCount = trainingImageCount, + averageConfidence = averageConfidence, + createdAt = now, + updatedAt = now, + lastUsed = null, + isActive = true + ) + } + + /** + * Serialize list of centroids to JSON + */ + private fun serializeCentroids(centroids: List): String { + val jsonArray = JSONArray() + centroids.forEach { centroid -> + val jsonObj = JSONObject() + jsonObj.put("embedding", JSONArray(centroid.embedding)) + jsonObj.put("effectiveTimestamp", centroid.effectiveTimestamp) + jsonObj.put("ageAtCapture", centroid.ageAtCapture) + jsonObj.put("photoCount", centroid.photoCount) + jsonObj.put("timeRangeMonths", centroid.timeRangeMonths) + jsonObj.put("avgConfidence", centroid.avgConfidence) + jsonArray.put(jsonObj) + } + return jsonArray.toString() + } + + /** + * Deserialize JSON to list of centroids + */ + private fun deserializeCentroids(json: String): List { + val jsonArray = JSONArray(json) + return (0 until jsonArray.length()).map { i -> + val jsonObj = jsonArray.getJSONObject(i) + val embeddingArray = jsonObj.getJSONArray("embedding") + val embedding = (0 until embeddingArray.length()).map { j -> + embeddingArray.getDouble(j).toFloat() + } + TemporalCentroid( + embedding = embedding, + effectiveTimestamp = jsonObj.getLong("effectiveTimestamp"), + ageAtCapture = if (jsonObj.isNull("ageAtCapture")) null else jsonObj.getDouble("ageAtCapture").toFloat(), + photoCount = jsonObj.getInt("photoCount"), + timeRangeMonths = jsonObj.getInt("timeRangeMonths"), + avgConfidence = jsonObj.getDouble("avgConfidence").toFloat() + ) + } + } } + fun getCentroids(): List { + return try { + FaceModelEntity.deserializeCentroids(centroidsJson) + } catch (e: Exception) { + emptyList() + } + } + + // Backwards compatibility: get first centroid as single embedding fun getEmbeddingArray(): FloatArray { - return embedding.split(",").map { it.toFloat() }.toFloatArray() + val centroids = getCentroids() + return if (centroids.isNotEmpty()) { + centroids.first().getEmbeddingArray() + } else { + FloatArray(192) // Empty embedding + } } } /** - * PhotoFaceTagEntity - NO DEFAULT VALUES + * TemporalCentroid - Represents a face appearance at a specific time period + */ +data class TemporalCentroid( + val embedding: List, // 192D vector + val effectiveTimestamp: Long, // Center of time window + val ageAtCapture: Float?, // Age in years (for children) + val photoCount: Int, // Number of photos in this cluster + val timeRangeMonths: Int, // Width of time window (e.g., 6 months) + val avgConfidence: Float // Quality indicator +) { + fun getEmbeddingArray(): FloatArray = embedding.toFloatArray() +} + +/** + * PhotoFaceTagEntity - Unchanged */ @Entity( tableName = "photo_face_tags", @@ -172,7 +326,7 @@ data class FaceModelEntity( data class PhotoFaceTagEntity( @PrimaryKey @ColumnInfo(name = "id") - val id: String, // ← No default + val id: String, @ColumnInfo(name = "imageId") val imageId: String, @@ -190,7 +344,7 @@ data class PhotoFaceTagEntity( val embedding: String, @ColumnInfo(name = "detectedAt") - val detectedAt: Long, // ← No default + val detectedAt: Long, @ColumnInfo(name = "verifiedByUser") val verifiedByUser: Boolean, @@ -228,4 +382,74 @@ data class PhotoFaceTagEntity( fun getEmbeddingArray(): FloatArray { return embedding.split(",").map { it.toFloat() }.toFloatArray() } +} + +/** + * PersonAgeTagEntity - NEW: Searchable age tags + */ +@Entity( + tableName = "person_age_tags", + foreignKeys = [ + ForeignKey( + entity = PersonEntity::class, + parentColumns = ["id"], + childColumns = ["personId"], + onDelete = ForeignKey.CASCADE + ), + ForeignKey( + entity = ImageEntity::class, + parentColumns = ["imageId"], + childColumns = ["imageId"], + onDelete = ForeignKey.CASCADE + ) + ], + indices = [ + Index(value = ["personId"]), + Index(value = ["imageId"]), + Index(value = ["ageAtCapture"]), + Index(value = ["tagValue"]) + ] +) +data class PersonAgeTagEntity( + @PrimaryKey + @ColumnInfo(name = "id") + val id: String, + + @ColumnInfo(name = "personId") + val personId: String, + + @ColumnInfo(name = "imageId") + val imageId: String, + + @ColumnInfo(name = "ageAtCapture") + val ageAtCapture: Int, + + @ColumnInfo(name = "tagValue") + val tagValue: String, // e.g., "emma_age3" + + @ColumnInfo(name = "confidence") + val confidence: Float, + + @ColumnInfo(name = "createdAt") + val createdAt: Long +) { + companion object { + fun create( + personId: String, + personName: String, + imageId: String, + ageAtCapture: Int, + confidence: Float + ): PersonAgeTagEntity { + return PersonAgeTagEntity( + id = UUID.randomUUID().toString(), + personId = personId, + imageId = imageId, + ageAtCapture = ageAtCapture, + tagValue = "${personName.lowercase().replace(" ", "_")}_age$ageAtCapture", + confidence = confidence, + createdAt = System.currentTimeMillis() + ) + } + } } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Userfeedbackentity.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Userfeedbackentity.kt new file mode 100644 index 0000000..e48313e --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Userfeedbackentity.kt @@ -0,0 +1,161 @@ +package com.placeholder.sherpai2.data.local.entity + +import androidx.room.Entity +import androidx.room.ForeignKey +import androidx.room.Index +import androidx.room.PrimaryKey +import java.util.UUID + +/** + * UserFeedbackEntity - Stores user corrections during cluster validation + * + * PURPOSE: + * - Capture which faces user marked as correct/incorrect + * - Ground truth data for improving clustering + * - Enable cluster refinement before training + * - Track confidence in automated detections + * + * USAGE FLOW: + * 1. Clustering creates initial clusters + * 2. User reviews ValidationPreview + * 3. User swipes faces: ✅ Correct / ❌ Incorrect + * 4. Feedback stored here + * 5. If too many incorrect → Re-cluster without those faces + * 6. If approved → Train model with confirmed faces only + * + * INDEXES: + * - imageId: Fast lookup of feedback for specific images + * - clusterId: Get all feedback for a cluster + * - feedbackType: Filter by correction type + * - personId: Track feedback after person created + */ +@Entity( + tableName = "user_feedback", + foreignKeys = [ + ForeignKey( + entity = ImageEntity::class, + parentColumns = ["imageId"], + childColumns = ["imageId"], + onDelete = ForeignKey.CASCADE + ), + ForeignKey( + entity = PersonEntity::class, + parentColumns = ["id"], + childColumns = ["personId"], + onDelete = ForeignKey.CASCADE + ) + ], + indices = [ + Index(value = ["imageId"]), + Index(value = ["clusterId"]), + Index(value = ["personId"]), + Index(value = ["feedbackType"]) + ] +) +data class UserFeedbackEntity( + @PrimaryKey + val id: String = UUID.randomUUID().toString(), + + /** + * Image containing the face + */ + val imageId: String, + + /** + * Face index within the image (0-based) + * Multiple faces per image possible + */ + val faceIndex: Int, + + /** + * Cluster ID from clustering (before person created) + * Null if feedback given after person exists + */ + val clusterId: Int?, + + /** + * Person ID if feedback is about an existing person + * Null during initial cluster validation + */ + val personId: String?, + + /** + * Type of feedback user provided + */ + val feedbackType: String, // FeedbackType enum stored as string + + /** + * Confidence score that led to this face being shown + * Helps identify if low confidence = more errors + */ + val originalConfidence: Float, + + /** + * Optional user note + */ + val userNote: String? = null, + + /** + * When feedback was provided + */ + val timestamp: Long = System.currentTimeMillis() +) { + companion object { + fun create( + imageId: String, + faceIndex: Int, + clusterId: Int? = null, + personId: String? = null, + feedbackType: FeedbackType, + originalConfidence: Float, + userNote: String? = null + ): UserFeedbackEntity { + return UserFeedbackEntity( + imageId = imageId, + faceIndex = faceIndex, + clusterId = clusterId, + personId = personId, + feedbackType = feedbackType.name, + originalConfidence = originalConfidence, + userNote = userNote + ) + } + } + + fun getFeedbackType(): FeedbackType { + return try { + FeedbackType.valueOf(feedbackType) + } catch (e: Exception) { + FeedbackType.UNCERTAIN + } + } +} + +/** + * FeedbackType - Types of user corrections + */ +enum class FeedbackType { + /** + * User confirmed this face IS the person + * Boosts confidence, use for training + */ + CONFIRMED_MATCH, + + /** + * User said this face is NOT the person + * Remove from cluster, exclude from training + */ + REJECTED_MATCH, + + /** + * User marked as outlier during cluster review + * Face doesn't belong in this cluster + */ + MARKED_OUTLIER, + + /** + * User is uncertain + * Skip this face for training, revisit later + */ + UNCERTAIN +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt b/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt index c7837e0..021e088 100644 --- a/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt +++ b/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt @@ -3,6 +3,9 @@ package com.placeholder.sherpai2.di import android.content.Context import androidx.room.Room import com.placeholder.sherpai2.data.local.AppDatabase +import com.placeholder.sherpai2.data.local.MIGRATION_7_8 +import com.placeholder.sherpai2.data.local.MIGRATION_8_9 +import com.placeholder.sherpai2.data.local.MIGRATION_9_10 import com.placeholder.sherpai2.data.local.dao.* import dagger.Module import dagger.Provides @@ -14,9 +17,17 @@ import javax.inject.Singleton /** * DatabaseModule - Provides database and ALL DAOs * - * DEVELOPMENT CONFIGURATION: - * - fallbackToDestructiveMigration enabled - * - No migrations required + * VERSION 10 UPDATES: + * - Added UserFeedbackDao for cluster refinement + * - Added MIGRATION_9_10 + * + * VERSION 9 UPDATES: + * - Added FaceCacheDao for per-face metadata + * - Added MIGRATION_8_9 + * + * PHASE 2 UPDATES: + * - Added PersonAgeTagDao + * - Added migration v7→v8 */ @Module @InstallIn(SingletonComponent::class) @@ -34,7 +45,12 @@ object DatabaseModule { AppDatabase::class.java, "sherpai.db" ) - .fallbackToDestructiveMigration() + // DEVELOPMENT MODE: Destructive migration (fresh install on schema change) + .fallbackToDestructiveMigration(dropAllTables = true) + + // PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration() + // .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) + .build() // ===== CORE DAOs ===== @@ -77,8 +93,21 @@ object DatabaseModule { fun providePhotoFaceTagDao(db: AppDatabase): PhotoFaceTagDao = db.photoFaceTagDao() + @Provides + fun providePersonAgeTagDao(db: AppDatabase): PersonAgeTagDao = + db.personAgeTagDao() + + @Provides + fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao = + db.faceCacheDao() + + @Provides + fun provideUserFeedbackDao(db: AppDatabase): UserFeedbackDao = + db.userFeedbackDao() + // ===== COLLECTIONS DAOs ===== + @Provides fun provideCollectionDao(db: AppDatabase): CollectionDao = db.collectionDao() -} +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt b/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt index 2e97da7..b57fc34 100644 --- a/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt +++ b/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt @@ -1,15 +1,16 @@ package com.placeholder.sherpai2.di import android.content.Context -import com.placeholder.sherpai2.data.local.dao.FaceModelDao -import com.placeholder.sherpai2.data.local.dao.ImageDao -import com.placeholder.sherpai2.data.local.dao.PersonDao -import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao +import androidx.work.WorkManager +import com.placeholder.sherpai2.data.local.dao.* import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterRefinementService import com.placeholder.sherpai2.domain.repository.ImageRepository import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl import com.placeholder.sherpai2.domain.repository.TaggingRepository +import com.placeholder.sherpai2.domain.validation.ValidationScanService import dagger.Binds import dagger.Module import dagger.Provides @@ -23,6 +24,10 @@ import javax.inject.Singleton * * UPDATED TO INCLUDE: * - FaceRecognitionRepository for face recognition operations + * - ValidationScanService for post-training validation + * - ClusterRefinementService for user feedback loop (NEW) + * - ClusterQualityAnalyzer for cluster validation + * - WorkManager for background tasks */ @Module @InstallIn(SingletonComponent::class) @@ -48,26 +53,6 @@ abstract class RepositoryModule { /** * Provide FaceRecognitionRepository - * - * Uses @Provides instead of @Binds because it needs Context parameter - * and multiple DAO dependencies - * - * INJECTED DEPENDENCIES: - * - Context: For FaceNetModel initialization - * - PersonDao: Access existing persons - * - ImageDao: Access existing images - * - FaceModelDao: Manage face models - * - PhotoFaceTagDao: Manage photo tags - * - * USAGE IN VIEWMODEL: - * ``` - * @HiltViewModel - * class MyViewModel @Inject constructor( - * private val faceRecognitionRepository: FaceRecognitionRepository - * ) : ViewModel() { - * // Use repository methods - * } - * ``` */ @Provides @Singleton @@ -86,5 +71,61 @@ abstract class RepositoryModule { photoFaceTagDao = photoFaceTagDao ) } + + /** + * Provide ValidationScanService + */ + @Provides + @Singleton + fun provideValidationScanService( + @ApplicationContext context: Context, + imageDao: ImageDao, + faceModelDao: FaceModelDao + ): ValidationScanService { + return ValidationScanService( + context = context, + imageDao = imageDao, + faceModelDao = faceModelDao + ) + } + + /** + * Provide ClusterRefinementService (NEW) + * Handles user feedback and cluster refinement workflow + */ + @Provides + @Singleton + fun provideClusterRefinementService( + faceCacheDao: FaceCacheDao, + userFeedbackDao: UserFeedbackDao, + qualityAnalyzer: ClusterQualityAnalyzer + ): ClusterRefinementService { + return ClusterRefinementService( + faceCacheDao = faceCacheDao, + userFeedbackDao = userFeedbackDao, + qualityAnalyzer = qualityAnalyzer + ) + } + + /** + * Provide ClusterQualityAnalyzer + * Validates cluster quality before training + */ + @Provides + @Singleton + fun provideClusterQualityAnalyzer(): ClusterQualityAnalyzer { + return ClusterQualityAnalyzer() + } + + /** + * Provide WorkManager for background tasks + */ + @Provides + @Singleton + fun provideWorkManager( + @ApplicationContext context: Context + ): WorkManager { + return WorkManager.getInstance(context) + } } } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt new file mode 100644 index 0000000..4241332 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt @@ -0,0 +1,285 @@ +package com.placeholder.sherpai2.domain.clustering + +import android.graphics.Rect +import android.util.Log +import javax.inject.Inject +import javax.inject.Singleton +import kotlin.math.sqrt + +/** + * ClusterQualityAnalyzer - Validate cluster quality BEFORE training + * + * RELAXED THRESHOLDS for real-world photos (social media, distant shots): + * - Face size: 3% (down from 15%) + * - Outlier threshold: 65% (down from 75%) + * - GOOD tier: 75% (down from 85%) + * - EXCELLENT tier: 85% (down from 95%) + */ +@Singleton +class ClusterQualityAnalyzer @Inject constructor() { + + companion object { + private const val TAG = "ClusterQuality" + private const val MIN_SOLO_PHOTOS = 6 + private const val MIN_FACE_SIZE_RATIO = 0.03f // 3% of image (RELAXED) + private const val MIN_FACE_DIMENSION_PIXELS = 50 // 50px minimum (RELAXED) + private const val FALLBACK_MIN_DIMENSION = 80 // Fallback when no dimensions + private const val MIN_INTERNAL_SIMILARITY = 0.75f + private const val OUTLIER_THRESHOLD = 0.65f // RELAXED + private const val EXCELLENT_THRESHOLD = 0.85f // RELAXED + private const val GOOD_THRESHOLD = 0.75f // RELAXED + } + + fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult { + Log.d(TAG, "========================================") + Log.d(TAG, "Analyzing cluster ${cluster.clusterId}") + Log.d(TAG, "Total faces: ${cluster.faces.size}") + + // Step 1: Filter to solo photos + val soloFaces = cluster.faces.filter { it.faceCount == 1 } + Log.d(TAG, "Solo photos: ${soloFaces.size}") + + // Step 2: Filter by face size + val largeFaces = soloFaces.filter { face -> + isFaceLargeEnough(face) + } + Log.d(TAG, "Large faces (>= 3%): ${largeFaces.size}") + + if (largeFaces.size < soloFaces.size) { + Log.d(TAG, "⚠️ Filtered out ${soloFaces.size - largeFaces.size} small faces") + } + + // Step 3: Calculate internal consistency + val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces) + + // Step 4: Clean faces + val cleanFaces = largeFaces.filter { it !in outliers } + Log.d(TAG, "Clean faces: ${cleanFaces.size}") + + // Step 5: Calculate quality score + val qualityScore = calculateQualityScore( + soloPhotoCount = soloFaces.size, + largeFaceCount = largeFaces.size, + cleanFaceCount = cleanFaces.size, + avgSimilarity = avgSimilarity, + totalFaces = cluster.faces.size + ) + Log.d(TAG, "Quality score: ${(qualityScore * 100).toInt()}%") + + // Step 6: Determine quality tier + val qualityTier = when { + qualityScore >= EXCELLENT_THRESHOLD -> ClusterQualityTier.EXCELLENT + qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD + else -> ClusterQualityTier.POOR + } + Log.d(TAG, "Quality tier: $qualityTier") + + val canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS + Log.d(TAG, "Can train: $canTrain") + Log.d(TAG, "========================================") + + return ClusterQualityResult( + originalFaceCount = cluster.faces.size, + soloPhotoCount = soloFaces.size, + largeFaceCount = largeFaces.size, + cleanFaceCount = cleanFaces.size, + avgInternalSimilarity = avgSimilarity, + outlierFaces = outliers, + cleanFaces = cleanFaces, + qualityScore = qualityScore, + qualityTier = qualityTier, + canTrain = canTrain, + warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier, avgSimilarity) + ) + } + + private fun isFaceLargeEnough(face: DetectedFaceWithEmbedding): Boolean { + val faceArea = face.boundingBox.width() * face.boundingBox.height() + + // Check 1: Absolute minimum + if (face.boundingBox.width() < MIN_FACE_DIMENSION_PIXELS || + face.boundingBox.height() < MIN_FACE_DIMENSION_PIXELS) { + return false + } + + // Check 2: Relative size if we have dimensions + if (face.imageWidth > 0 && face.imageHeight > 0) { + val imageArea = face.imageWidth * face.imageHeight + val faceRatio = faceArea.toFloat() / imageArea.toFloat() + return faceRatio >= MIN_FACE_SIZE_RATIO + } + + // Fallback: Use absolute size + return face.boundingBox.width() >= FALLBACK_MIN_DIMENSION && + face.boundingBox.height() >= FALLBACK_MIN_DIMENSION + } + + private fun analyzeInternalConsistency( + faces: List + ): Pair> { + if (faces.size < 2) { + Log.d(TAG, "Less than 2 faces, skipping consistency check") + return 0f to emptyList() + } + + Log.d(TAG, "Analyzing ${faces.size} faces for internal consistency") + + val centroid = calculateCentroid(faces.map { it.embedding }) + + val centroidSum = centroid.sum() + Log.d(TAG, "Centroid sum: $centroidSum, first5=[${centroid.take(5).joinToString()}]") + + val similarities = faces.mapIndexed { index, face -> + val similarity = cosineSimilarity(face.embedding, centroid) + Log.d(TAG, "Face $index similarity to centroid: $similarity") + face to similarity + } + + val avgSimilarity = similarities.map { it.second }.average().toFloat() + Log.d(TAG, "Average internal similarity: $avgSimilarity") + + val outliers = similarities + .filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD } + .map { (face, _) -> face } + + Log.d(TAG, "Found ${outliers.size} outliers (threshold=$OUTLIER_THRESHOLD)") + + return avgSimilarity to outliers + } + + private fun calculateCentroid(embeddings: List): FloatArray { + val size = embeddings.first().size + val centroid = FloatArray(size) { 0f } + + embeddings.forEach { embedding -> + for (i in embedding.indices) { + centroid[i] += embedding[i] + } + } + + val count = embeddings.size.toFloat() + for (i in centroid.indices) { + centroid[i] /= count + } + + val norm = sqrt(centroid.map { it * it }.sum()) + return if (norm > 0) { + centroid.map { it / norm }.toFloatArray() + } else { + centroid + } + } + + private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { + var dotProduct = 0f + var normA = 0f + var normB = 0f + + for (i in a.indices) { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) + } + + private fun calculateQualityScore( + soloPhotoCount: Int, + largeFaceCount: Int, + cleanFaceCount: Int, + avgSimilarity: Float, + totalFaces: Int + ): Float { + val soloRatio = soloPhotoCount.toFloat() / totalFaces.toFloat().coerceAtLeast(1f) + val soloPhotoScore = soloRatio.coerceIn(0f, 1f) * 0.25f + + val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.25f + + val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.20f + + val similarityScore = avgSimilarity * 0.30f + + return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore + } + + private fun generateWarnings( + soloPhotoCount: Int, + largeFaceCount: Int, + cleanFaceCount: Int, + qualityTier: ClusterQualityTier, + avgSimilarity: Float + ): List { + val warnings = mutableListOf() + + when (qualityTier) { + ClusterQualityTier.POOR -> { + warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!") + warnings.add("Do NOT train on this cluster - it will create a bad model.") + + if (avgSimilarity < 0.70f) { + warnings.add("Low internal similarity (${(avgSimilarity * 100).toInt()}%) suggests mixed identities.") + } + } + ClusterQualityTier.GOOD -> { + warnings.add("⚠️ Review outlier faces before training") + + if (cleanFaceCount < 10) { + warnings.add("Consider adding more high-quality photos for better results.") + } + } + ClusterQualityTier.EXCELLENT -> { + // No warnings + } + } + + if (soloPhotoCount < MIN_SOLO_PHOTOS) { + warnings.add("Need at least $MIN_SOLO_PHOTOS solo photos (have $soloPhotoCount)") + } + + if (largeFaceCount < 6) { + warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)") + warnings.add("Tip: Use close-up photos where the face is clearly visible") + } + + if (cleanFaceCount < 6) { + warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)") + } + + if (qualityTier == ClusterQualityTier.EXCELLENT) { + warnings.add("✅ Excellent quality! This cluster is ready for training.") + warnings.add("High-quality photos with consistent facial features detected.") + } + + return warnings + } +} + +data class ClusterQualityResult( + val originalFaceCount: Int, + val soloPhotoCount: Int, + val largeFaceCount: Int, + val cleanFaceCount: Int, + val avgInternalSimilarity: Float, + val outlierFaces: List, + val cleanFaces: List, + val qualityScore: Float, + val qualityTier: ClusterQualityTier, + val canTrain: Boolean, + val warnings: List +) { + fun getSummary(): String = when (qualityTier) { + ClusterQualityTier.EXCELLENT -> + "Excellent quality cluster with $cleanFaceCount high-quality photos ready for training." + ClusterQualityTier.GOOD -> + "Good quality cluster with $cleanFaceCount usable photos. Review outliers before training." + ClusterQualityTier.POOR -> + "Poor quality cluster. May contain multiple people or low-quality photos. Add more photos or split cluster." + } +} + +enum class ClusterQualityTier { + EXCELLENT, // 85%+ + GOOD, // 75-84% + POOR // <75% +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterrefinementservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterrefinementservice.kt new file mode 100644 index 0000000..c6f5f5c --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterrefinementservice.kt @@ -0,0 +1,415 @@ +package com.placeholder.sherpai2.domain.clustering + +import android.util.Log +import com.placeholder.sherpai2.data.local.dao.FaceCacheDao +import com.placeholder.sherpai2.data.local.dao.UserFeedbackDao +import com.placeholder.sherpai2.data.local.entity.FeedbackType +import com.placeholder.sherpai2.data.local.entity.UserFeedbackEntity +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import javax.inject.Inject +import javax.inject.Singleton +import kotlin.math.sqrt + +/** + * ClusterRefinementService - Handle user feedback and cluster refinement + * + * PURPOSE: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Close the feedback loop between user corrections and clustering + * + * WORKFLOW: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * 1. Clustering produces initial clusters + * 2. User reviews in ValidationPreview + * 3. User marks faces: ✅ Correct / ❌ Incorrect / ❓ Uncertain + * 4. If too many incorrect → Call refineCluster() + * 5. Re-cluster WITHOUT incorrect faces + * 6. Show updated validation preview + * 7. Repeat until user approves + * + * BENEFITS: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * - Prevents bad models from being created + * - Learns from user corrections + * - Iterative improvement + * - Ground truth data for future enhancements + */ +@Singleton +class ClusterRefinementService @Inject constructor( + private val faceCacheDao: FaceCacheDao, + private val userFeedbackDao: UserFeedbackDao, + private val qualityAnalyzer: ClusterQualityAnalyzer +) { + + companion object { + private const val TAG = "ClusterRefinement" + + // Thresholds for refinement decisions + private const val MIN_REJECTION_RATIO = 0.15f // 15% rejected → refine + private const val MIN_CONFIRMED_FACES = 6 // Need at least 6 good faces + private const val MAX_REFINEMENT_ITERATIONS = 3 // Prevent infinite loops + } + + /** + * Store user feedback for faces in a cluster + * + * @param cluster The cluster being reviewed + * @param feedbackMap Map of face index → feedback type + * @param originalConfidences Map of face index → original detection confidence + * @return Number of feedback items stored + */ + suspend fun storeFeedback( + cluster: FaceCluster, + feedbackMap: Map, + originalConfidences: Map = emptyMap() + ): Int = withContext(Dispatchers.IO) { + + val feedbackEntities = feedbackMap.map { (face, feedbackType) -> + UserFeedbackEntity.create( + imageId = face.imageId, + faceIndex = 0, // We don't track faceIndex in DetectedFaceWithEmbedding yet + clusterId = cluster.clusterId, + personId = null, // Not created yet + feedbackType = feedbackType, + originalConfidence = originalConfidences[face] ?: face.confidence + ) + } + + userFeedbackDao.insertAll(feedbackEntities) + + Log.d(TAG, "Stored ${feedbackEntities.size} feedback items for cluster ${cluster.clusterId}") + feedbackEntities.size + } + + /** + * Check if cluster needs refinement based on user feedback + * + * Criteria: + * - Too many rejected faces (> 15%) + * - Too few confirmed faces (< 6) + * - High rejection rate for cluster suggests mixed identities + * + * @return RefinementRecommendation with action and reason + */ + suspend fun shouldRefineCluster( + cluster: FaceCluster + ): RefinementRecommendation = withContext(Dispatchers.Default) { + + val feedback = withContext(Dispatchers.IO) { + userFeedbackDao.getFeedbackForCluster(cluster.clusterId) + } + + if (feedback.isEmpty()) { + return@withContext RefinementRecommendation( + shouldRefine = false, + reason = "No feedback provided yet" + ) + } + + val totalFeedback = feedback.size + val rejectedCount = feedback.count { it.getFeedbackType() == FeedbackType.REJECTED_MATCH } + val confirmedCount = feedback.count { it.getFeedbackType() == FeedbackType.CONFIRMED_MATCH } + val uncertainCount = feedback.count { it.getFeedbackType() == FeedbackType.UNCERTAIN } + + val rejectionRatio = rejectedCount.toFloat() / totalFeedback.toFloat() + + Log.d(TAG, "Cluster ${cluster.clusterId} feedback: " + + "$confirmedCount confirmed, $rejectedCount rejected, $uncertainCount uncertain") + + // Check 1: Too many rejections + if (rejectionRatio > MIN_REJECTION_RATIO) { + return@withContext RefinementRecommendation( + shouldRefine = true, + reason = "High rejection rate (${(rejectionRatio * 100).toInt()}%) suggests mixed identities", + confirmedCount = confirmedCount, + rejectedCount = rejectedCount, + uncertainCount = uncertainCount + ) + } + + // Check 2: Too few confirmed faces after removing rejected + val effectiveConfirmedCount = confirmedCount - rejectedCount + if (effectiveConfirmedCount < MIN_CONFIRMED_FACES) { + return@withContext RefinementRecommendation( + shouldRefine = true, + reason = "Only $effectiveConfirmedCount faces remain after removing rejected faces (need $MIN_CONFIRMED_FACES)", + confirmedCount = confirmedCount, + rejectedCount = rejectedCount, + uncertainCount = uncertainCount + ) + } + + // Cluster is good! + RefinementRecommendation( + shouldRefine = false, + reason = "Cluster quality acceptable: $confirmedCount confirmed, $rejectedCount rejected", + confirmedCount = confirmedCount, + rejectedCount = rejectedCount, + uncertainCount = uncertainCount + ) + } + + /** + * Refine cluster by removing rejected faces and re-clustering + * + * ALGORITHM: + * 1. Get all rejected faces from feedback + * 2. Remove those faces from cluster + * 3. Recalculate cluster centroid + * 4. Re-run quality analysis + * 5. Return refined cluster + * + * @param cluster Original cluster to refine + * @return Refined cluster without rejected faces + */ + suspend fun refineCluster( + cluster: FaceCluster, + iterationNumber: Int = 1 + ): ClusterRefinementResult = withContext(Dispatchers.Default) { + + Log.d(TAG, "Refining cluster ${cluster.clusterId} (iteration $iterationNumber)") + + // Guard against infinite refinement + if (iterationNumber > MAX_REFINEMENT_ITERATIONS) { + return@withContext ClusterRefinementResult( + success = false, + refinedCluster = null, + errorMessage = "Maximum refinement iterations reached. Cluster quality still poor.", + facesRemoved = 0, + facesRemaining = cluster.faces.size + ) + } + + // Get rejected faces + val feedback = withContext(Dispatchers.IO) { + userFeedbackDao.getRejectedFacesForCluster(cluster.clusterId) + } + + val rejectedImageIds = feedback.map { it.imageId }.toSet() + + if (rejectedImageIds.isEmpty()) { + return@withContext ClusterRefinementResult( + success = false, + refinedCluster = cluster, + errorMessage = "No rejected faces to remove", + facesRemoved = 0, + facesRemaining = cluster.faces.size + ) + } + + // Remove rejected faces + val cleanFaces = cluster.faces.filter { it.imageId !in rejectedImageIds } + + Log.d(TAG, "Removed ${rejectedImageIds.size} rejected faces, ${cleanFaces.size} remain") + + // Check if we have enough faces left + if (cleanFaces.size < MIN_CONFIRMED_FACES) { + return@withContext ClusterRefinementResult( + success = false, + refinedCluster = null, + errorMessage = "Too few faces remaining after removing rejected faces (${cleanFaces.size} < $MIN_CONFIRMED_FACES)", + facesRemoved = rejectedImageIds.size, + facesRemaining = cleanFaces.size + ) + } + + // Recalculate centroid + val newCentroid = calculateCentroid(cleanFaces.map { it.embedding }) + + // Select new representative faces + val newRepresentatives = selectRepresentativeFacesByCentroid(cleanFaces, newCentroid, count = 6) + + // Create refined cluster + val refinedCluster = FaceCluster( + clusterId = cluster.clusterId, + faces = cleanFaces, + representativeFaces = newRepresentatives, + photoCount = cleanFaces.map { it.imageId }.distinct().size, + averageConfidence = cleanFaces.map { it.confidence }.average().toFloat(), + estimatedAge = cluster.estimatedAge, // Keep same estimate + potentialSiblings = cluster.potentialSiblings // Keep same siblings + ) + + // Re-run quality analysis + val qualityResult = qualityAnalyzer.analyzeCluster(refinedCluster) + + Log.d(TAG, "Refined cluster quality: ${qualityResult.qualityTier} " + + "(${qualityResult.cleanFaceCount} clean faces)") + + ClusterRefinementResult( + success = true, + refinedCluster = refinedCluster, + qualityResult = qualityResult, + facesRemoved = rejectedImageIds.size, + facesRemaining = cleanFaces.size, + newQualityTier = qualityResult.qualityTier + ) + } + + /** + * Get feedback summary for cluster + * + * Returns human-readable summary like: + * "15 confirmed, 3 rejected, 2 uncertain" + */ + suspend fun getFeedbackSummary(clusterId: Int): FeedbackSummary = withContext(Dispatchers.IO) { + val feedback = userFeedbackDao.getFeedbackForCluster(clusterId) + + val confirmed = feedback.count { it.getFeedbackType() == FeedbackType.CONFIRMED_MATCH } + val rejected = feedback.count { it.getFeedbackType() == FeedbackType.REJECTED_MATCH } + val uncertain = feedback.count { it.getFeedbackType() == FeedbackType.UNCERTAIN } + val outliers = feedback.count { it.getFeedbackType() == FeedbackType.MARKED_OUTLIER } + + FeedbackSummary( + totalFeedback = feedback.size, + confirmedCount = confirmed, + rejectedCount = rejected, + uncertainCount = uncertain, + outlierCount = outliers, + rejectionRatio = if (feedback.isNotEmpty()) { + rejected.toFloat() / feedback.size.toFloat() + } else 0f + ) + } + + /** + * Filter cluster to only confirmed faces + * + * Use Case: User has reviewed cluster, now create model using ONLY confirmed faces + */ + suspend fun getConfirmedFaces(cluster: FaceCluster): List = + withContext(Dispatchers.Default) { + + val confirmedFeedback = withContext(Dispatchers.IO) { + userFeedbackDao.getConfirmedFacesForCluster(cluster.clusterId) + } + + val confirmedImageIds = confirmedFeedback.map { it.imageId }.toSet() + + // If no explicit confirmations, assume all non-rejected faces are OK + if (confirmedImageIds.isEmpty()) { + val rejectedFeedback = withContext(Dispatchers.IO) { + userFeedbackDao.getRejectedFacesForCluster(cluster.clusterId) + } + val rejectedImageIds = rejectedFeedback.map { it.imageId }.toSet() + + return@withContext cluster.faces.filter { it.imageId !in rejectedImageIds } + } + + // Return only explicitly confirmed faces + cluster.faces.filter { it.imageId in confirmedImageIds } + } + + /** + * Calculate centroid from embeddings + */ + private fun calculateCentroid(embeddings: List): FloatArray { + if (embeddings.isEmpty()) return FloatArray(0) + + val size = embeddings.first().size + val centroid = FloatArray(size) { 0f } + + embeddings.forEach { embedding -> + for (i in embedding.indices) { + centroid[i] += embedding[i] + } + } + + val count = embeddings.size.toFloat() + for (i in centroid.indices) { + centroid[i] /= count + } + + // Normalize + val norm = sqrt(centroid.map { it * it }.sum()) + return if (norm > 0) { + centroid.map { it / norm }.toFloatArray() + } else { + centroid + } + } + + /** + * Select representative faces closest to centroid + */ + private fun selectRepresentativeFacesByCentroid( + faces: List, + centroid: FloatArray, + count: Int + ): List { + if (faces.size <= count) return faces + + val facesWithDistance = faces.map { face -> + val similarity = cosineSimilarity(face.embedding, centroid) + val distance = 1 - similarity + face to distance + } + + return facesWithDistance + .sortedBy { it.second } + .take(count) + .map { it.first } + } + + /** + * Cosine similarity calculation + */ + private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { + var dotProduct = 0f + var normA = 0f + var normB = 0f + + for (i in a.indices) { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) + } +} + +/** + * Result of refinement analysis + */ +data class RefinementRecommendation( + val shouldRefine: Boolean, + val reason: String, + val confirmedCount: Int = 0, + val rejectedCount: Int = 0, + val uncertainCount: Int = 0 +) + +/** + * Result of cluster refinement + */ +data class ClusterRefinementResult( + val success: Boolean, + val refinedCluster: FaceCluster?, + val qualityResult: ClusterQualityResult? = null, + val errorMessage: String? = null, + val facesRemoved: Int, + val facesRemaining: Int, + val newQualityTier: ClusterQualityTier? = null +) + +/** + * Summary of user feedback for a cluster + */ +data class FeedbackSummary( + val totalFeedback: Int, + val confirmedCount: Int, + val rejectedCount: Int, + val uncertainCount: Int, + val outlierCount: Int, + val rejectionRatio: Float +) { + fun getDisplayText(): String { + val parts = mutableListOf() + if (confirmedCount > 0) parts.add("$confirmedCount confirmed") + if (rejectedCount > 0) parts.add("$rejectedCount rejected") + if (uncertainCount > 0) parts.add("$uncertainCount uncertain") + return parts.joinToString(", ") + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt new file mode 100644 index 0000000..b561e3a --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt @@ -0,0 +1,962 @@ +package com.placeholder.sherpai2.domain.clustering + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.Rect +import android.net.Uri +import android.util.Log +import com.google.android.gms.tasks.Tasks +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.placeholder.sherpai2.data.local.dao.FaceCacheDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity +import com.placeholder.sherpai2.data.local.entity.ImageEntity +import com.placeholder.sherpai2.ml.FaceNetModel +import com.placeholder.sherpai2.ui.discover.DiscoverySettings +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.withContext +import javax.inject.Inject +import javax.inject.Singleton +import kotlin.math.max +import kotlin.math.min +import kotlin.math.sqrt + +/** + * FaceClusteringService - FIXED to properly use metadata cache + * + * THE CRITICAL FIX: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Path 2 now CORRECTLY checks for metadata cache WITHOUT requiring embeddings + * Uses countFacesWithoutEmbeddings() which counts faces that HAVE metadata + * but DON'T have embeddings yet + * + * 3-PATH STRATEGY (CORRECTED): + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Path 1: Cached embeddings exist → Instant (< 2 sec) + * Path 2: Metadata cache exists → Generate embeddings for quality faces (~3 min) ← FIXED! + * Path 3: No cache → Full scan (~8 min) + */ +@Singleton +class FaceClusteringService @Inject constructor( + @ApplicationContext private val context: Context, + private val imageDao: ImageDao, + private val faceCacheDao: FaceCacheDao +) { + + private val semaphore = Semaphore(3) + + companion object { + private const val TAG = "FaceClustering" + private const val MAX_FACES_TO_CLUSTER = 2000 + + // Path selection thresholds + private const val MIN_CACHED_EMBEDDINGS = 20 // Path 1 + private const val MIN_QUALITY_METADATA = 50 // Path 2 + private const val MIN_STANDARD_FACES = 10 // Absolute minimum + + // IoU matching threshold + private const val IOU_THRESHOLD = 0.5f + } + + suspend fun discoverPeople( + strategy: ClusteringStrategy = ClusteringStrategy.PREMIUM_SOLO_ONLY, + maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER, + onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } + ): ClusteringResult = withContext(Dispatchers.Default) { + + val startTime = System.currentTimeMillis() + + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "CACHE-AWARE DISCOVERY STARTED") + Log.d(TAG, "════════════════════════════════════════") + + val result = when (strategy) { + ClusteringStrategy.PREMIUM_SOLO_ONLY -> { + clusterPremiumSoloFaces(maxFacesToCluster, onProgress) + } + ClusteringStrategy.STANDARD_SOLO_ONLY -> { + clusterStandardSoloFaces(maxFacesToCluster, onProgress) + } + ClusteringStrategy.TWO_PHASE -> { + clusterPremiumSoloFaces(maxFacesToCluster, onProgress) + } + ClusteringStrategy.LEGACY_ALL_FACES -> { + clusterAllFacesLegacy(maxFacesToCluster, onProgress) + } + } + + val elapsedTime = System.currentTimeMillis() - startTime + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "Discovery Complete!") + Log.d(TAG, "Clusters found: ${result.clusters.size}") + Log.d(TAG, "Time: ${elapsedTime / 1000}s") + Log.d(TAG, "════════════════════════════════════════") + + result.copy(processingTimeMs = elapsedTime) + } + + /** + * FIXED: 3-Path Selection with proper metadata checking + */ + private suspend fun clusterPremiumSoloFaces( + maxFaces: Int, + onProgress: (Int, Int, String) -> Unit + ): ClusteringResult = withContext(Dispatchers.Default) { + + onProgress(5, 100, "Checking cache...") + + // ═════════════════════════════════════════════════════════ + // PATH 1: Check for cached embeddings (INSTANT) + // ═════════════════════════════════════════════════════════ + Log.d(TAG, "Path 1: Checking for cached embeddings...") + + val embeddingCount = withContext(Dispatchers.IO) { + try { + faceCacheDao.countFacesWithEmbeddings(minQuality = 0.6f) + } catch (e: Exception) { + Log.w(TAG, "Error counting embeddings: ${e.message}") + 0 + } + } + + Log.d(TAG, "Found $embeddingCount faces with cached embeddings") + + if (embeddingCount >= MIN_CACHED_EMBEDDINGS) { + Log.d(TAG, "✅ PATH 1 SUCCESS: Using $embeddingCount cached embeddings") + + val cachedFaces = withContext(Dispatchers.IO) { + faceCacheDao.getAllQualityFaces( + minRatio = 0.03f, + minQuality = 0.6f, + limit = Int.MAX_VALUE + ) + } + + return@withContext clusterCachedEmbeddings(cachedFaces, maxFaces, onProgress) + } + + // ═════════════════════════════════════════════════════════ + // PATH 2: Check for metadata cache (FAST) + // ═════════════════════════════════════════════════════════ + Log.d(TAG, "Path 1 insufficient, trying Path 2...") + Log.d(TAG, "Path 2: Checking for quality metadata...") + + // THE CRITICAL FIX: Count faces WITH metadata but WITHOUT embeddings + val metadataCount = withContext(Dispatchers.IO) { + try { + faceCacheDao.countFacesWithoutEmbeddings(minQuality = 0.6f) + } catch (e: Exception) { + Log.w(TAG, "Error counting metadata: ${e.message}") + 0 + } + } + + Log.d(TAG, "Found $metadataCount faces in metadata cache (without embeddings)") + + if (metadataCount >= MIN_QUALITY_METADATA) { + Log.d(TAG, "✅ PATH 2 SUCCESS: Using metadata cache") + + val qualityMetadata = withContext(Dispatchers.IO) { + faceCacheDao.getQualityFacesWithoutEmbeddings( + minRatio = 0.03f, + minQuality = 0.6f, + limit = 5000 + ) + } + + Log.d(TAG, "Loaded ${qualityMetadata.size} quality face metadata entries") + return@withContext clusterWithQualityPrefiltering(qualityMetadata, maxFaces, onProgress) + } + + // ═════════════════════════════════════════════════════════ + // PATH 3: Full scan (SLOW, last resort) + // ═════════════════════════════════════════════════════════ + Log.w(TAG, "Path 2 insufficient, falling back to Path 3 (full scan)") + Log.w(TAG, "⚠️ PATH 3: Full library scan (this will take several minutes)") + Log.w(TAG, "Cache stats: $embeddingCount with embeddings, $metadataCount metadata only") + + onProgress(10, 100, "No cache found, performing full scan...") + return@withContext clusterAllFacesLegacy(maxFaces, onProgress) + } + + /** + * Path 1: Cluster using cached embeddings (INSTANT) + */ + private suspend fun clusterCachedEmbeddings( + cachedFaces: List, + maxFaces: Int, + onProgress: (Int, Int, String) -> Unit + ): ClusteringResult = withContext(Dispatchers.Default) { + + Log.d(TAG, "Converting ${cachedFaces.size} cached faces to clustering format...") + onProgress(30, 100, "Using ${cachedFaces.size} cached faces...") + + val allFaces = cachedFaces.mapNotNull { cached -> + val embedding = cached.getEmbedding() ?: return@mapNotNull null + + DetectedFaceWithEmbedding( + imageId = cached.imageId, + imageUri = "", + capturedAt = cached.detectedAt, + embedding = embedding, + boundingBox = cached.getBoundingBox(), + confidence = cached.confidence, + faceCount = 1, + imageWidth = cached.imageWidth, + imageHeight = cached.imageHeight + ) + } + + if (allFaces.isEmpty()) { + return@withContext ClusteringResult( + clusters = emptyList(), + totalFacesAnalyzed = 0, + processingTimeMs = 0, + errorMessage = "No valid cached embeddings found" + ) + } + + Log.d(TAG, "Clustering ${allFaces.size} cached faces...") + onProgress(50, 100, "Clustering ${allFaces.size} faces...") + + val rawClusters = performDBSCAN( + faces = allFaces.take(maxFaces), + epsilon = 0.22f, + minPoints = 3 + ) + + onProgress(75, 100, "Analyzing relationships...") + + val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters) + + onProgress(90, 100, "Finalizing clusters...") + + val clusters = rawClusters.mapIndexed { index, cluster -> + FaceCluster( + clusterId = index, + faces = cluster.faces, + representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6), + photoCount = cluster.faces.map { it.imageId }.distinct().size, + averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(), + estimatedAge = estimateAge(cluster.faces), + potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph) + ) + }.sortedByDescending { it.photoCount } + + onProgress(100, 100, "Complete!") + + ClusteringResult( + clusters = clusters, + totalFacesAnalyzed = allFaces.size, + processingTimeMs = 0, + strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY + ) + } + + /** + * Path 2: CORRECTED to work with metadata cache + * + * Generates embeddings on-demand and saves them with IoU matching + */ + private suspend fun clusterWithQualityPrefiltering( + qualityFacesMetadata: List, + maxFaces: Int, + onProgress: (Int, Int, String) -> Unit + ): ClusteringResult = withContext(Dispatchers.Default) { + + Log.d(TAG, "Starting Path 2: Quality metadata pre-filtering") + Log.d(TAG, "Quality faces in metadata: ${qualityFacesMetadata.size}") + + onProgress(15, 100, "Pre-filtering complete...") + + // Extract unique imageIds from metadata + val imageIdsToProcess = qualityFacesMetadata + .map { it.imageId } + .distinct() + + Log.d(TAG, "Pre-filtered to ${imageIdsToProcess.size} images with quality faces") + + // Load only those specific images + val imagesToProcess = withContext(Dispatchers.IO) { + imageDao.getImagesByIds(imageIdsToProcess) + } + + Log.d(TAG, "Loading ${imagesToProcess.size} quality photos...") + onProgress(20, 100, "Generating embeddings for ${imagesToProcess.size} quality photos...") + + val faceNetModel = FaceNetModel(context) + val detector = FaceDetection.getClient( + FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) + .setMinFaceSize(0.15f) + .build() + ) + + try { + val allFaces = mutableListOf() + var iouMatchSuccesses = 0 + var iouMatchFailures = 0 + + coroutineScope { + val jobs = imagesToProcess.mapIndexed { index, image -> + async(Dispatchers.IO) { + semaphore.acquire() + try { + val bitmap = loadBitmapDownsampled( + Uri.parse(image.imageUri), + 768 + ) ?: return@async Triple(emptyList(), 0, 0) + + val inputImage = InputImage.fromBitmap(bitmap, 0) + val mlKitFaces = Tasks.await(detector.process(inputImage)) + + val imageWidth = bitmap.width + val imageHeight = bitmap.height + + // Get cached faces for THIS specific image + val cachedFacesForImage = qualityFacesMetadata.filter { + it.imageId == image.imageId + } + + var localSuccesses = 0 + var localFailures = 0 + + val facesForImage = mutableListOf() + + mlKitFaces.forEach { mlFace -> + val qualityCheck = FaceQualityFilter.validateForDiscovery( + face = mlFace, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + + if (!qualityCheck.isValid) { + return@forEach + } + + try { + // Crop and generate embedding + val faceBitmap = Bitmap.createBitmap( + bitmap, + mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1), + mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1), + mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left), + mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top) + ) + + val embedding = faceNetModel.generateEmbedding(faceBitmap) + faceBitmap.recycle() + + // Add to results + facesForImage.add( + DetectedFaceWithEmbedding( + imageId = image.imageId, + imageUri = image.imageUri, + capturedAt = image.capturedAt, + embedding = embedding, + boundingBox = mlFace.boundingBox, + confidence = qualityCheck.confidenceScore, + faceCount = mlKitFaces.size, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + ) + + // Save embedding to cache with IoU matching + val matched = matchAndSaveEmbedding( + imageId = image.imageId, + detectedBox = mlFace.boundingBox, + embedding = embedding, + cachedFaces = cachedFacesForImage + ) + + if (matched) localSuccesses++ else localFailures++ + + } catch (e: Exception) { + Log.w(TAG, "Failed to process face: ${e.message}") + } + } + + bitmap.recycle() + + // Update progress + if (index % 20 == 0) { + val progress = 20 + (index * 60 / imagesToProcess.size) + onProgress(progress, 100, "Processed $index/${imagesToProcess.size} photos...") + } + + Triple(facesForImage, localSuccesses, localFailures) + } finally { + semaphore.release() + } + } + } + + val results = jobs.awaitAll() + results.forEach { (faces, successes, failures) -> + allFaces.addAll(faces) + iouMatchSuccesses += successes + iouMatchFailures += failures + } + } + + Log.d(TAG, "IoU Matching Results:") + Log.d(TAG, " Successful matches: $iouMatchSuccesses") + Log.d(TAG, " Failed matches: $iouMatchFailures") + val successRate = if (iouMatchSuccesses + iouMatchFailures > 0) { + (iouMatchSuccesses.toFloat() / (iouMatchSuccesses + iouMatchFailures) * 100).toInt() + } else 0 + Log.d(TAG, " Success rate: $successRate%") + Log.d(TAG, "✅ Embeddings saved to cache with IoU matching") + + if (allFaces.isEmpty()) { + return@withContext ClusteringResult( + clusters = emptyList(), + totalFacesAnalyzed = 0, + processingTimeMs = 0, + errorMessage = "No faces detected with sufficient quality" + ) + } + + // Cluster + onProgress(80, 100, "Clustering ${allFaces.size} faces...") + + val rawClusters = performDBSCAN(allFaces.take(maxFaces), 0.22f, 3) + val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters) + + onProgress(90, 100, "Finalizing clusters...") + + val clusters = rawClusters.mapIndexed { index, cluster -> + FaceCluster( + clusterId = index, + faces = cluster.faces, + representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6), + photoCount = cluster.faces.map { it.imageId }.distinct().size, + averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(), + estimatedAge = estimateAge(cluster.faces), + potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph) + ) + }.sortedByDescending { it.photoCount } + + onProgress(100, 100, "Complete!") + + ClusteringResult( + clusters = clusters, + totalFacesAnalyzed = allFaces.size, + processingTimeMs = 0, + strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY + ) + } finally { + detector.close() + } + } + + /** + * IoU matching and saving - handles non-deterministic ML Kit order + */ + private suspend fun matchAndSaveEmbedding( + imageId: String, + detectedBox: Rect, + embedding: FloatArray, + cachedFaces: List + ): Boolean { + if (cachedFaces.isEmpty()) { + return false + } + + // Find best matching cached face by IoU + var bestMatch: FaceCacheEntity? = null + var bestIoU = 0f + + cachedFaces.forEach { cached -> + val iou = calculateIoU(detectedBox, cached.getBoundingBox()) + if (iou > bestIoU) { + bestIoU = iou + bestMatch = cached + } + } + + // Save if IoU meets threshold + if (bestMatch != null && bestIoU >= IOU_THRESHOLD) { + try { + withContext(Dispatchers.IO) { + val updated = bestMatch!!.copy( + embedding = embedding.joinToString(",") + ) + faceCacheDao.update(updated) + } + return true + } catch (e: Exception) { + Log.e(TAG, "Failed to save embedding: ${e.message}") + return false + } + } + + return false + } + + /** + * Calculate IoU between two bounding boxes + */ + private fun calculateIoU(rect1: Rect, rect2: Rect): Float { + val intersectionLeft = max(rect1.left, rect2.left) + val intersectionTop = max(rect1.top, rect2.top) + val intersectionRight = min(rect1.right, rect2.right) + val intersectionBottom = min(rect1.bottom, rect2.bottom) + + if (intersectionLeft >= intersectionRight || intersectionTop >= intersectionBottom) { + return 0f + } + + val intersectionArea = (intersectionRight - intersectionLeft) * (intersectionBottom - intersectionTop) + val area1 = rect1.width() * rect1.height() + val area2 = rect2.width() * rect2.height() + val unionArea = area1 + area2 - intersectionArea + + return if (unionArea > 0) { + intersectionArea.toFloat() / unionArea.toFloat() + } else { + 0f + } + } + + private suspend fun clusterStandardSoloFaces( + maxFaces: Int, + onProgress: (Int, Int, String) -> Unit + ): ClusteringResult = clusterPremiumSoloFaces(maxFaces, onProgress) + + /** + * Path 3: Legacy full scan (fallback only) + */ + private suspend fun clusterAllFacesLegacy( + maxFaces: Int, + onProgress: (Int, Int, String) -> Unit + ): ClusteringResult = withContext(Dispatchers.Default) { + + Log.w(TAG, "⚠️ Running LEGACY full scan") + + onProgress(10, 100, "Loading all images...") + + val allImages = withContext(Dispatchers.IO) { + imageDao.getAllImages() + } + + Log.d(TAG, "Processing ${allImages.size} images...") + onProgress(20, 100, "Detecting faces in ${allImages.size} photos...") + + val faceNetModel = FaceNetModel(context) + val detector = FaceDetection.getClient( + FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) + .setMinFaceSize(0.15f) + .build() + ) + + try { + val allFaces = mutableListOf() + + coroutineScope { + val jobs = allImages.mapIndexed { index, image -> + async(Dispatchers.IO) { + semaphore.acquire() + try { + val bitmap = loadBitmapDownsampled( + Uri.parse(image.imageUri), + 768 + ) ?: return@async emptyList() + + val inputImage = InputImage.fromBitmap(bitmap, 0) + val faces = Tasks.await(detector.process(inputImage)) + + val imageWidth = bitmap.width + val imageHeight = bitmap.height + + val faceEmbeddings = faces.mapNotNull { face -> + val qualityCheck = FaceQualityFilter.validateForDiscovery( + face = face, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + + if (!qualityCheck.isValid) return@mapNotNull null + + try { + val faceBitmap = Bitmap.createBitmap( + bitmap, + face.boundingBox.left.coerceIn(0, bitmap.width - 1), + face.boundingBox.top.coerceIn(0, bitmap.height - 1), + face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left), + face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top) + ) + + val embedding = faceNetModel.generateEmbedding(faceBitmap) + faceBitmap.recycle() + + DetectedFaceWithEmbedding( + imageId = image.imageId, + imageUri = image.imageUri, + capturedAt = image.capturedAt, + embedding = embedding, + boundingBox = face.boundingBox, + confidence = qualityCheck.confidenceScore, + faceCount = faces.size, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + } catch (e: Exception) { + null + } + } + + bitmap.recycle() + + if (index % 20 == 0) { + val progress = 20 + (index * 60 / allImages.size) + onProgress(progress, 100, "Processed $index/${allImages.size} photos...") + } + + faceEmbeddings + } finally { + semaphore.release() + } + } + } + + jobs.awaitAll().flatten().forEach { allFaces.add(it) } + } + + if (allFaces.isEmpty()) { + return@withContext ClusteringResult( + clusters = emptyList(), + totalFacesAnalyzed = 0, + processingTimeMs = 0, + errorMessage = "No faces detected" + ) + } + + onProgress(80, 100, "Clustering ${allFaces.size} faces...") + + val rawClusters = performDBSCAN(allFaces.take(maxFaces), 0.22f, 3) + val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters) + + onProgress(90, 100, "Finalizing clusters...") + + val clusters = rawClusters.mapIndexed { index, cluster -> + FaceCluster( + clusterId = index, + faces = cluster.faces, + representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6), + photoCount = cluster.faces.map { it.imageId }.distinct().size, + averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(), + estimatedAge = estimateAge(cluster.faces), + potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph) + ) + }.sortedByDescending { it.photoCount } + + onProgress(100, 100, "Complete!") + + ClusteringResult( + clusters = clusters, + totalFacesAnalyzed = allFaces.size, + processingTimeMs = 0, + strategy = ClusteringStrategy.LEGACY_ALL_FACES + ) + } finally { + detector.close() + } + } + + // REPLACE the discoverPeopleWithSettings method (lines 679-716) with this: + + suspend fun discoverPeopleWithSettings( + settings: DiscoverySettings, + onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } + ): ClusteringResult = withContext(Dispatchers.Default) { + + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "🎛️ DISCOVERY WITH CUSTOM SETTINGS") + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "Settings received:") + Log.d(TAG, " • minFaceSize: ${settings.minFaceSize} (${(settings.minFaceSize * 100).toInt()}%)") + Log.d(TAG, " • minQuality: ${settings.minQuality} (${(settings.minQuality * 100).toInt()}%)") + Log.d(TAG, " • epsilon: ${settings.epsilon}") + Log.d(TAG, "════════════════════════════════════════") + + // Get quality faces using settings + val qualityMetadata = withContext(Dispatchers.IO) { + faceCacheDao.getQualityFacesWithoutEmbeddings( + minRatio = settings.minFaceSize, + minQuality = settings.minQuality, + limit = 5000 + ) + } + + Log.d(TAG, "Found ${qualityMetadata.size} faces matching quality settings") + Log.d(TAG, " • Query used: minRatio=${settings.minFaceSize}, minQuality=${settings.minQuality}") + + // Adjust threshold based on library size + val minRequired = if (qualityMetadata.size < 50) 30 else 50 + + Log.d(TAG, "Path selection:") + Log.d(TAG, " • Faces available: ${qualityMetadata.size}") + Log.d(TAG, " • Minimum required: $minRequired") + + if (qualityMetadata.size >= minRequired) { + Log.d(TAG, "✅ Using Path 2 (quality pre-filtering)") + Log.d(TAG, "════════════════════════════════════════") + + // Use Path 2 (quality pre-filtering) + return@withContext clusterWithQualityPrefiltering( + qualityFacesMetadata = qualityMetadata, + maxFaces = MAX_FACES_TO_CLUSTER, + onProgress = onProgress + ) + } else { + Log.d(TAG, "⚠️ Using fallback path (standard discovery)") + Log.d(TAG, " • Reason: ${qualityMetadata.size} < $minRequired") + Log.d(TAG, "════════════════════════════════════════") + + // Fallback to regular discovery (no Path 3, use existing methods) + Log.w(TAG, "Insufficient metadata (${qualityMetadata.size} < $minRequired), using standard discovery") + + // Use existing discoverPeople with appropriate strategy + val strategy = if (settings.minQuality >= 0.7f) { + ClusteringStrategy.PREMIUM_SOLO_ONLY + } else { + ClusteringStrategy.STANDARD_SOLO_ONLY + } + + return@withContext discoverPeople( + strategy = strategy, + maxFacesToCluster = MAX_FACES_TO_CLUSTER, + onProgress = onProgress + ) + } + } + // Clustering algorithms (unchanged) + private fun performDBSCAN(faces: List, epsilon: Float, minPoints: Int): List { + val visited = mutableSetOf() + val clusters = mutableListOf() + var clusterId = 0 + + for (i in faces.indices) { + if (i in visited) continue + val neighbors = findNeighbors(i, faces, epsilon) + if (neighbors.size < minPoints) { + visited.add(i) + continue + } + + val cluster = mutableListOf() + val queue = ArrayDeque(listOf(i)) + + while (queue.isNotEmpty()) { + val pointIdx = queue.removeFirst() + if (pointIdx in visited) continue + + visited.add(pointIdx) + cluster.add(faces[pointIdx]) + + val pointNeighbors = findNeighbors(pointIdx, faces, epsilon) + if (pointNeighbors.size >= minPoints) { + queue.addAll(pointNeighbors.filter { it !in visited }) + } + } + + if (cluster.size >= minPoints) { + clusters.add(RawCluster(clusterId++, cluster)) + } + } + + return clusters + } + + private fun findNeighbors(pointIdx: Int, faces: List, epsilon: Float): List { + val point = faces[pointIdx] + return faces.indices.filter { i -> + if (i == pointIdx) return@filter false + val otherFace = faces[i] + val similarity = cosineSimilarity(point.embedding, otherFace.embedding) + val appearTogether = point.imageId == otherFace.imageId + val effectiveEpsilon = if (appearTogether) epsilon * 0.7f else epsilon + similarity > (1 - effectiveEpsilon) + } + } + + private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { + var dotProduct = 0f + var normA = 0f + var normB = 0f + for (i in a.indices) { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + return dotProduct / (sqrt(normA) * sqrt(normB)) + } + + private fun buildCoOccurrenceGraph(clusters: List): Map> { + val graph = mutableMapOf>() + for (i in clusters.indices) { + graph[i] = mutableMapOf() + val imageIds = clusters[i].faces.map { it.imageId }.toSet() + for (j in clusters.indices) { + if (i == j) continue + val sharedImages = clusters[j].faces.count { it.imageId in imageIds } + if (sharedImages > 0) { + graph[i]!![j] = sharedImages + } + } + } + return graph + } + + private fun findPotentialSiblings(cluster: RawCluster, allClusters: List, coOccurrenceGraph: Map>): List { + val clusterIdx = allClusters.indexOf(cluster) + if (clusterIdx == -1) return emptyList() + return coOccurrenceGraph[clusterIdx] + ?.filter { (_, count) -> count >= 5 } + ?.keys + ?.toList() + ?: emptyList() + } + + fun selectRepresentativeFacesByCentroid(faces: List, count: Int): List { + if (faces.size <= count) return faces + val centroid = calculateCentroid(faces.map { it.embedding }) + val facesWithDistance = faces.map { face -> + val distance = 1 - cosineSimilarity(face.embedding, centroid) + face to distance + } + val sortedByProximity = facesWithDistance.sortedBy { it.second } + val representatives = mutableListOf() + representatives.add(sortedByProximity.first().first) + val remainingFaces = sortedByProximity.drop(1).take(count * 3) + val sortedByTime = remainingFaces.map { it.first }.sortedBy { it.capturedAt } + if (sortedByTime.isNotEmpty()) { + val step = sortedByTime.size / (count - 1).coerceAtLeast(1) + for (i in 0 until (count - 1)) { + val index = (i * step).coerceAtMost(sortedByTime.size - 1) + representatives.add(sortedByTime[index]) + } + } + return representatives.take(count) + } + + private fun calculateCentroid(embeddings: List): FloatArray { + if (embeddings.isEmpty()) return FloatArray(0) + val size = embeddings.first().size + val centroid = FloatArray(size) { 0f } + embeddings.forEach { embedding -> + for (i in embedding.indices) { + centroid[i] += embedding[i] + } + } + val count = embeddings.size.toFloat() + for (i in centroid.indices) { + centroid[i] /= count + } + val norm = sqrt(centroid.map { it * it }.sum()) + return if (norm > 0) { + centroid.map { it / norm }.toFloatArray() + } else { + centroid + } + } + + private fun estimateAge(faces: List): AgeEstimate { + val timestamps = faces.map { it.capturedAt }.sorted() + if (timestamps.isEmpty() || timestamps.last() == 0L) return AgeEstimate.UNKNOWN + val span = timestamps.last() - timestamps.first() + val spanYears = span / (365.25 * 24 * 60 * 60 * 1000) + return if (spanYears > 3.0) AgeEstimate.CHILD else AgeEstimate.UNKNOWN + } + + private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? { + return try { + val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true } + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, opts) + } + var sample = 1 + while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) { + sample *= 2 + } + val finalOpts = BitmapFactory.Options().apply { + inSampleSize = sample + inPreferredConfig = Bitmap.Config.RGB_565 + } + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, finalOpts) + } + } catch (e: Exception) { + null + } + } +} + +enum class ClusteringStrategy { + PREMIUM_SOLO_ONLY, + STANDARD_SOLO_ONLY, + TWO_PHASE, + LEGACY_ALL_FACES +} + +data class DetectedFaceWithEmbedding( + val imageId: String, + val imageUri: String, + val capturedAt: Long, + val embedding: FloatArray, + val boundingBox: android.graphics.Rect, + val confidence: Float, + val faceCount: Int = 1, + val imageWidth: Int = 0, + val imageHeight: Int = 0 +) { + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false + other as DetectedFaceWithEmbedding + return imageId == other.imageId + } + override fun hashCode(): Int = imageId.hashCode() +} + +data class RawCluster( + val clusterId: Int, + val faces: List +) + +data class FaceCluster( + val clusterId: Int, + val faces: List, + val representativeFaces: List, + val photoCount: Int, + val averageConfidence: Float, + val estimatedAge: AgeEstimate, + val potentialSiblings: List +) + +data class ClusteringResult( + val clusters: List, + val totalFacesAnalyzed: Int, + val processingTimeMs: Long, + val errorMessage: String? = null, + val strategy: ClusteringStrategy = ClusteringStrategy.PREMIUM_SOLO_ONLY +) + +enum class AgeEstimate { + CHILD, + ADULT, + UNKNOWN +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Facequalityfilter.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Facequalityfilter.kt new file mode 100644 index 0000000..63b5abf --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Facequalityfilter.kt @@ -0,0 +1,140 @@ +package com.placeholder.sherpai2.domain.clustering + +import com.google.mlkit.vision.face.Face +import com.google.mlkit.vision.face.FaceLandmark +import kotlin.math.abs +import kotlin.math.pow +import kotlin.math.sqrt + +/** + * FaceQualityFilter - Quality filtering for face detection + * + * PURPOSE: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Two modes with different strictness: + * 1. Discovery: RELAXED (we want to find people, be permissive) + * 2. Scanning: MINIMAL (only reject obvious garbage) + * + * FILTERS OUT: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * ✅ Ghost faces (no eyes detected) + * ✅ Tiny faces (< 10% of image) + * ✅ Extreme angles (> 45°) + * ⚠️ Side profiles (both eyes required) + * + * ALLOWS: + * ✅ Moderate angles (up to 45°) + * ✅ Faces without tracking ID (not reliable) + * ✅ Faces without nose (some angles don't show nose) + */ +object FaceQualityFilter { + + /** + * Validate face for Discovery/Clustering + * + * RELAXED thresholds - we want to find people, not reject everything + */ + fun validateForDiscovery( + face: Face, + imageWidth: Int, + imageHeight: Int + ): FaceQualityValidation { + val issues = mutableListOf() + + // ===== CHECK 1: Eye Detection (CRITICAL) ===== + val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE) + val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE) + + if (leftEye == null || rightEye == null) { + issues.add("Missing eye landmarks") + return FaceQualityValidation(false, issues, 0f) + } + + // ===== CHECK 2: Head Pose (RELAXED - 45°) ===== + val headEulerAngleY = face.headEulerAngleY + val headEulerAngleZ = face.headEulerAngleZ + val headEulerAngleX = face.headEulerAngleX + + if (abs(headEulerAngleY) > 45f) { + issues.add("Head turned too far") + } + + if (abs(headEulerAngleZ) > 45f) { + issues.add("Head tilted too much") + } + + if (abs(headEulerAngleX) > 40f) { + issues.add("Head angle too extreme") + } + + // ===== CHECK 3: Face Size (RELAXED - 10%) ===== + val faceWidthRatio = face.boundingBox.width() / imageWidth.toFloat() + val faceHeightRatio = face.boundingBox.height() / imageHeight.toFloat() + + if (faceWidthRatio < 0.10f) { + issues.add("Face too small") + } + + if (faceHeightRatio < 0.10f) { + issues.add("Face too small") + } + + // ===== CHECK 4: Eye Distance (OPTIONAL) ===== + if (leftEye != null && rightEye != null) { + val eyeDistance = sqrt( + (rightEye.position.x - leftEye.position.x).toDouble().pow(2.0) + + (rightEye.position.y - leftEye.position.y).toDouble().pow(2.0) + ).toFloat() + + val eyeDistanceRatio = eyeDistance / face.boundingBox.width() + if (eyeDistanceRatio < 0.15f || eyeDistanceRatio > 0.65f) { + issues.add("Abnormal eye spacing") + } + } + + // ===== CONFIDENCE SCORE ===== + val poseScore = 1f - (abs(headEulerAngleY) + abs(headEulerAngleZ) + abs(headEulerAngleX)) / 270f + val sizeScore = (faceWidthRatio + faceHeightRatio) / 2f + val nose = face.getLandmark(FaceLandmark.NOSE_BASE) + val landmarkScore = if (nose != null) 1f else 0.8f + + val confidenceScore = (poseScore * 0.4f + sizeScore * 0.3f + landmarkScore * 0.3f).coerceIn(0f, 1f) + + // ===== VERDICT (RELAXED - 0.5 threshold) ===== + val isValid = issues.isEmpty() && confidenceScore >= 0.5f + + return FaceQualityValidation(isValid, issues, confidenceScore) + } + + /** + * Quick check for scanning phase (permissive) + */ + fun validateForScanning( + face: Face, + imageWidth: Int, + imageHeight: Int + ): Boolean { + val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE) + val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE) + + if (leftEye == null && rightEye == null) { + return false + } + + val faceWidthRatio = face.boundingBox.width() / imageWidth.toFloat() + if (faceWidthRatio < 0.08f) { + return false + } + + return true + } +} + +data class FaceQualityValidation( + val isValid: Boolean, + val issues: List, + val confidenceScore: Float +) { + val passesStrictValidation: Boolean get() = isValid && confidenceScore >= 0.7f + val passesModerateValidation: Boolean get() = isValid && confidenceScore >= 0.5f +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Temporalclusteringservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Temporalclusteringservice.kt new file mode 100644 index 0000000..8561e9d --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Temporalclusteringservice.kt @@ -0,0 +1,597 @@ +package com.placeholder.sherpai2.domain.clustering + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.net.Uri +import android.util.Log +import com.google.android.gms.tasks.Tasks +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.placeholder.sherpai2.data.local.dao.FaceCacheDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.ml.FaceNetModel +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.withContext +import java.util.Calendar +import javax.inject.Inject +import javax.inject.Singleton +import kotlin.math.sqrt +import kotlin.random.Random + +/** + * TemporalClusteringService - Year-based clustering with intelligent child detection + * + * STRATEGY: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * 1. Process ALL photos (no limits) + * 2. Apply strict quality filter (FaceQualityFilter) + * 3. Group faces by YEAR + * 4. Cluster within each year + * 5. Link clusters across years (same person) + * 6. Detect children (changing appearance over years) + * 7. Generate tags: "Emma_2020", "Emma_Age_2", "Brad_Pitt" + * + * CHILD DETECTION: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * A person is a CHILD if: + * - Appears across 3+ years + * - Face embeddings change significantly between years (>0.20 distance) + * - Consistent presence (not just random appearances) + * + * OUTPUT: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Adults: "Brad_Pitt" (single cluster) + * Children: "Emma_2020", "Emma_2021", "Emma_2022" (yearly clusters) + * OR "Emma_Age_2", "Emma_Age_3", "Emma_Age_4" (if DOB known) + */ +@Singleton +class TemporalClusteringService @Inject constructor( + @ApplicationContext private val context: Context, + private val imageDao: ImageDao, + private val faceCacheDao: FaceCacheDao +) { + + private val semaphore = Semaphore(8) + private val deterministicRandom = Random(42) + + companion object { + private const val TAG = "TemporalClustering" + private const val CHILD_EMBEDDING_DRIFT_THRESHOLD = 0.20f // Significant change + private const val CHILD_MIN_YEARS = 3 // Must span 3+ years + private const val ADULT_SIMILARITY_THRESHOLD = 0.80f // 80% similar across years + private const val CHILD_SIMILARITY_THRESHOLD = 0.70f // 70% similar (more lenient) + } + + /** + * Discover people with year-based clustering + * + * @return List of AnnotatedCluster (year-specific clusters with metadata) + */ + suspend fun discoverPeopleByYear( + onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } + ): TemporalClusteringResult = withContext(Dispatchers.Default) { + + val startTime = System.currentTimeMillis() + + onProgress(5, 100, "Loading all photos...") + + // STEP 1: Load ALL images (no limit) + val allImages = withContext(Dispatchers.IO) { + imageDao.getAllImages() + } + + if (allImages.isEmpty()) { + return@withContext TemporalClusteringResult( + clusters = emptyList(), + totalPhotosProcessed = 0, + totalFacesDetected = 0, + processingTimeMs = 0, + errorMessage = "No photos in library" + ) + } + + Log.d(TAG, "Processing ${allImages.size} photos (no limit)") + + onProgress(10, 100, "Detecting high-quality faces...") + + // STEP 2: Detect faces with STRICT quality filtering + val faceNetModel = FaceNetModel(context) + val detector = FaceDetection.getClient( + FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) + .setMinFaceSize(0.15f) + .build() + ) + + try { + val allFaces = mutableListOf() + + coroutineScope { + val jobs = allImages.mapIndexed { index, image -> + async(Dispatchers.IO) { + semaphore.acquire() + try { + val bitmap = loadBitmapDownsampled(Uri.parse(image.imageUri), 768) + ?: return@async emptyList() + + val inputImage = InputImage.fromBitmap(bitmap, 0) + val faces = Tasks.await(detector.process(inputImage)) + + val imageWidth = bitmap.width + val imageHeight = bitmap.height + + val validFaces = faces.mapNotNull { face -> + // Apply STRICT quality filter + val qualityCheck = FaceQualityFilter.validateForDiscovery( + face = face, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + + if (!qualityCheck.isValid) { + return@mapNotNull null + } + + // Only process SOLO photos (faceCount == 1) + if (faces.size != 1) { + return@mapNotNull null + } + + try { + val faceBitmap = Bitmap.createBitmap( + bitmap, + face.boundingBox.left.coerceIn(0, bitmap.width - 1), + face.boundingBox.top.coerceIn(0, bitmap.height - 1), + face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left), + face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top) + ) + + val embedding = faceNetModel.generateEmbedding(faceBitmap) + faceBitmap.recycle() + + DetectedFaceWithEmbedding( + imageId = image.imageId, + imageUri = image.imageUri, + capturedAt = image.capturedAt, + embedding = embedding, + boundingBox = face.boundingBox, + confidence = qualityCheck.confidenceScore, + faceCount = 1, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + } catch (e: Exception) { + null + } + } + + bitmap.recycle() + + if (index % 50 == 0) { + val progress = 10 + (index * 40 / allImages.size) + onProgress(progress, 100, "Processed $index/${allImages.size} photos...") + } + + validFaces + } finally { + semaphore.release() + } + } + } + + jobs.awaitAll().flatten().forEach { allFaces.add(it) } + } + + Log.d(TAG, "Detected ${allFaces.size} high-quality solo faces") + + if (allFaces.isEmpty()) { + return@withContext TemporalClusteringResult( + clusters = emptyList(), + totalPhotosProcessed = allImages.size, + totalFacesDetected = 0, + processingTimeMs = System.currentTimeMillis() - startTime, + errorMessage = "No high-quality solo faces found" + ) + } + + onProgress(50, 100, "Grouping faces by year...") + + // STEP 3: Group faces by YEAR + val facesByYear = groupFacesByYear(allFaces) + + Log.d(TAG, "Faces grouped into ${facesByYear.size} years") + + onProgress(60, 100, "Clustering within each year...") + + // STEP 4: Cluster within each year + val yearClusters = mutableListOf() + + facesByYear.forEach { (year, faces) -> + Log.d(TAG, "Clustering $year: ${faces.size} faces") + + val rawClusters = performDBSCAN( + faces = faces, + epsilon = 0.24f, + minPoints = 3 + ) + + rawClusters.forEach { rawCluster -> + yearClusters.add( + YearCluster( + year = year, + faces = rawCluster.faces, + centroid = calculateCentroid(rawCluster.faces.map { it.embedding }) + ) + ) + } + } + + Log.d(TAG, "Created ${yearClusters.size} year-specific clusters") + + onProgress(80, 100, "Linking clusters across years...") + + // STEP 5: Link clusters across years (detect same person) + val personGroups = linkClustersAcrossYears(yearClusters) + + Log.d(TAG, "Identified ${personGroups.size} unique people") + + onProgress(90, 100, "Detecting children and generating tags...") + + // STEP 6: Detect children and generate final clusters + val annotatedClusters = personGroups.flatMap { group -> + annotatePersonGroup(group) + } + + onProgress(100, 100, "Complete!") + + TemporalClusteringResult( + clusters = annotatedClusters.sortedByDescending { it.cluster.faces.size }, + totalPhotosProcessed = allImages.size, + totalFacesDetected = allFaces.size, + processingTimeMs = System.currentTimeMillis() - startTime + ) + + } finally { + faceNetModel.close() + detector.close() + } + } + + /** + * Group faces by year of capture + */ + private fun groupFacesByYear(faces: List): Map> { + return faces.groupBy { face -> + val calendar = Calendar.getInstance() + calendar.timeInMillis = face.capturedAt + calendar.get(Calendar.YEAR).toString() + } + } + + /** + * Link year clusters that belong to the same person + */ + private fun linkClustersAcrossYears(yearClusters: List): List { + val sortedClusters = yearClusters.sortedBy { it.year } + val visited = mutableSetOf() + val personGroups = mutableListOf() + + for (cluster in sortedClusters) { + if (cluster in visited) continue + + val group = mutableListOf() + group.add(cluster) + visited.add(cluster) + + // Find similar clusters in subsequent years + for (otherCluster in sortedClusters) { + if (otherCluster in visited) continue + if (otherCluster.year <= cluster.year) continue + + val similarity = cosineSimilarity(cluster.centroid, otherCluster.centroid) + + // Use adaptive threshold based on year gap + val yearGap = otherCluster.year.toInt() - cluster.year.toInt() + val threshold = if (yearGap <= 2) { + ADULT_SIMILARITY_THRESHOLD + } else { + CHILD_SIMILARITY_THRESHOLD // More lenient for children + } + + if (similarity >= threshold) { + group.add(otherCluster) + visited.add(otherCluster) + } + } + + personGroups.add(PersonGroup(clusters = group)) + } + + return personGroups + } + + /** + * Annotate person group (detect if child, generate tags) + */ + private fun annotatePersonGroup(group: PersonGroup): List { + val sortedClusters = group.clusters.sortedBy { it.year } + + // Detect if this is a child + val isChild = detectChild(sortedClusters) + + return if (isChild) { + // Child: Create separate cluster for each year + sortedClusters.map { yearCluster -> + AnnotatedCluster( + cluster = FaceCluster( + clusterId = 0, + faces = yearCluster.faces, + representativeFaces = selectRepresentativeFaces(yearCluster.faces, 6), + photoCount = yearCluster.faces.size, + averageConfidence = yearCluster.faces.map { it.confidence }.average().toFloat(), + estimatedAge = AgeEstimate.CHILD, + potentialSiblings = emptyList() + ), + year = yearCluster.year, + isChild = true, + suggestedName = null, + suggestedAge = estimateAgeInYear(yearCluster.year, sortedClusters) + ) + } + } else { + // Adult: Single cluster combining all years + val allFaces = sortedClusters.flatMap { it.faces } + listOf( + AnnotatedCluster( + cluster = FaceCluster( + clusterId = 0, + faces = allFaces, + representativeFaces = selectRepresentativeFaces(allFaces, 6), + photoCount = allFaces.size, + averageConfidence = allFaces.map { it.confidence }.average().toFloat(), + estimatedAge = AgeEstimate.ADULT, + potentialSiblings = emptyList() + ), + year = "All Years", + isChild = false, + suggestedName = null, + suggestedAge = null + ) + ) + } + } + + /** + * Detect if person group represents a child + */ + private fun detectChild(clusters: List): Boolean { + if (clusters.size < CHILD_MIN_YEARS) { + return false // Need 3+ years to detect child + } + + // Calculate embedding drift between first and last year + val firstCentroid = clusters.first().centroid + val lastCentroid = clusters.last().centroid + val drift = 1 - cosineSimilarity(firstCentroid, lastCentroid) + + // If embeddings changed significantly, likely a child + return drift >= CHILD_EMBEDDING_DRIFT_THRESHOLD + } + + /** + * Estimate age in specific year based on cluster position + */ + private fun estimateAgeInYear(targetYear: String, allClusters: List): Int? { + val sortedClusters = allClusters.sortedBy { it.year } + val firstYear = sortedClusters.first().year.toInt() + val targetYearInt = targetYear.toInt() + + val yearsSinceFirst = targetYearInt - firstYear + return yearsSinceFirst + 1 // Start at age 1 + } + + /** + * Select representative faces + */ + private fun selectRepresentativeFaces( + faces: List, + count: Int + ): List { + if (faces.size <= count) return faces + + val centroid = calculateCentroid(faces.map { it.embedding }) + + return faces + .map { face -> face to (1 - cosineSimilarity(face.embedding, centroid)) } + .sortedBy { it.second } + .take(count) + .map { it.first } + } + + /** + * DBSCAN clustering + */ + private fun performDBSCAN( + faces: List, + epsilon: Float, + minPoints: Int + ): List { + val visited = mutableSetOf() + val clusters = mutableListOf() + var clusterId = 0 + + for (i in faces.indices) { + if (i in visited) continue + + val neighbors = findNeighbors(i, faces, epsilon) + + if (neighbors.size < minPoints) { + visited.add(i) + continue + } + + val cluster = mutableListOf() + val queue = ArrayDeque(neighbors) + + while (queue.isNotEmpty()) { + val pointIdx = queue.removeFirst() + if (pointIdx in visited) continue + + visited.add(pointIdx) + cluster.add(faces[pointIdx]) + + val pointNeighbors = findNeighbors(pointIdx, faces, epsilon) + if (pointNeighbors.size >= minPoints) { + queue.addAll(pointNeighbors.filter { it !in visited }) + } + } + + if (cluster.size >= minPoints) { + clusters.add(RawCluster(clusterId++, cluster)) + } + } + + return clusters + } + + private fun findNeighbors( + pointIdx: Int, + faces: List, + epsilon: Float + ): List { + val point = faces[pointIdx] + return faces.indices.filter { i -> + if (i == pointIdx) return@filter false + val similarity = cosineSimilarity(point.embedding, faces[i].embedding) + similarity > (1 - epsilon) + } + } + + private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { + var dotProduct = 0f + var normA = 0f + var normB = 0f + + for (i in a.indices) { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + return dotProduct / (sqrt(normA) * sqrt(normB)) + } + + private fun calculateCentroid(embeddings: List): FloatArray { + if (embeddings.isEmpty()) return FloatArray(0) + + val size = embeddings.first().size + val centroid = FloatArray(size) { 0f } + + embeddings.forEach { embedding -> + for (i in embedding.indices) { + centroid[i] += embedding[i] + } + } + + val count = embeddings.size.toFloat() + for (i in centroid.indices) { + centroid[i] /= count + } + + val norm = sqrt(centroid.map { it * it }.sum()) + if (norm > 0) { + return centroid.map { it / norm }.toFloatArray() + } + + return centroid + } + + private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? { + return try { + val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true } + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, opts) + } + + var sample = 1 + while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) { + sample *= 2 + } + + val finalOpts = BitmapFactory.Options().apply { + inSampleSize = sample + inPreferredConfig = Bitmap.Config.RGB_565 + } + + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, finalOpts) + } + } catch (e: Exception) { + null + } + } +} + +/** + * Year-specific cluster + */ +data class YearCluster( + val year: String, + val faces: List, + val centroid: FloatArray +) + +/** + * Group of year clusters belonging to same person + */ +data class PersonGroup( + val clusters: List +) + +/** + * Annotated cluster with temporal metadata + */ +data class AnnotatedCluster( + val cluster: FaceCluster, + val year: String, + val isChild: Boolean, + val suggestedName: String?, + val suggestedAge: Int? +) { + /** + * Generate tag for this cluster + * Examples: + * - Child: "Emma_2020" or "Emma_Age_2" + * - Adult: "Brad_Pitt" + */ + fun generateTag(name: String): String { + return if (isChild) { + if (suggestedAge != null) { + "${name}_Age_${suggestedAge}" + } else { + "${name}_${year}" + } + } else { + name + } + } +} + +/** + * Result of temporal clustering + */ +data class TemporalClusteringResult( + val clusters: List, + val totalPhotosProcessed: Int, + val totalFacesDetected: Int, + val processingTimeMs: Long, + val errorMessage: String? = null +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/training/Clustertrainingservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/training/Clustertrainingservice.kt new file mode 100644 index 0000000..32d8de9 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/training/Clustertrainingservice.kt @@ -0,0 +1,253 @@ +package com.placeholder.sherpai2.domain.training + +import android.content.Context +import android.graphics.BitmapFactory +import android.net.Uri +import com.placeholder.sherpai2.data.local.dao.FaceModelDao +import com.placeholder.sherpai2.data.local.dao.PersonDao +import com.placeholder.sherpai2.data.local.entity.FaceModelEntity +import com.placeholder.sherpai2.data.local.entity.PersonEntity +import com.placeholder.sherpai2.data.local.entity.TemporalCentroid +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult +import com.placeholder.sherpai2.domain.clustering.FaceCluster +import com.placeholder.sherpai2.ml.FaceNetModel +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import javax.inject.Inject +import javax.inject.Singleton +import kotlin.math.abs + +/** + * ClusterTrainingService - Train multi-centroid face models from clusters + * + * STRATEGY: + * 1. VALIDATE cluster quality FIRST (prevent training on dirty/mixed clusters) + * 2. For children: Create multiple temporal centroids (one per age period) + * 3. For adults: Create single centroid (stable appearance) + * 4. Use K-Means clustering on timestamps to find age groups + * 5. Calculate centroid for each time period + */ +@Singleton +class ClusterTrainingService @Inject constructor( + @ApplicationContext private val context: Context, + private val personDao: PersonDao, + private val faceModelDao: FaceModelDao, + private val qualityAnalyzer: ClusterQualityAnalyzer +) { + + private val faceNetModel by lazy { FaceNetModel(context) } + + /** + * Analyze cluster quality before training + * + * Call this BEFORE trainFromCluster() to check if cluster is clean + */ + suspend fun analyzeClusterQuality(cluster: FaceCluster): ClusterQualityResult { + return qualityAnalyzer.analyzeCluster(cluster) + } + + /** + * Train a person from an auto-discovered cluster + * + * @param cluster The discovered cluster + * @param qualityResult Optional pre-computed quality analysis (recommended) + * @return PersonId on success + */ + suspend fun trainFromCluster( + cluster: FaceCluster, + name: String, + dateOfBirth: Long?, + isChild: Boolean, + siblingClusterIds: List, + qualityResult: ClusterQualityResult? = null, + onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } + ): String = withContext(Dispatchers.Default) { + + onProgress(0, 100, "Creating person...") + + // Step 1: Use clean faces if quality analysis was done + val facesToUse = if (qualityResult != null && qualityResult.cleanFaces.isNotEmpty()) { + // Use clean faces (outliers removed) + qualityResult.cleanFaces + } else { + // Use all faces (legacy behavior) + cluster.faces + } + + if (facesToUse.size < 6) { + throw Exception("Need at least 6 clean faces for training (have ${facesToUse.size})") + } + + // Step 2: Create PersonEntity + val person = PersonEntity.create( + name = name, + dateOfBirth = dateOfBirth, + isChild = isChild, + siblingIds = emptyList(), // Will update after siblings are created + relationship = if (isChild) "Child" else null + ) + + withContext(Dispatchers.IO) { + personDao.insert(person) + } + + onProgress(20, 100, "Analyzing face variations...") + + // Step 3: Use pre-computed embeddings from clustering + // CRITICAL: These embeddings are already face-specific, even in group photos! + // The clustering phase already cropped and generated embeddings for each face. + val facesWithEmbeddings = facesToUse.map { face -> + Triple( + face.imageUri, + face.capturedAt, + face.embedding // ✅ Use existing embedding (already cropped to face) + ) + } + + onProgress(50, 100, "Creating face model...") + + // Step 4: Create centroids based on whether person is a child + val centroids = if (isChild && dateOfBirth != null) { + createTemporalCentroidsForChild( + facesWithEmbeddings = facesWithEmbeddings, + dateOfBirth = dateOfBirth + ) + } else { + createSingleCentroid(facesWithEmbeddings) + } + + onProgress(80, 100, "Saving model...") + + // Step 5: Calculate average confidence + val avgConfidence = centroids.map { it.avgConfidence }.average().toFloat() + + // Step 6: Create FaceModelEntity + val faceModel = FaceModelEntity.createFromCentroids( + personId = person.id, + centroids = centroids, + trainingImageCount = facesToUse.size, + averageConfidence = avgConfidence + ) + + withContext(Dispatchers.IO) { + faceModelDao.insertFaceModel(faceModel) + } + + onProgress(100, 100, "Complete!") + + person.id + } + + /** + * Create temporal centroids for a child + * Groups faces by age and creates one centroid per age period + */ + private fun createTemporalCentroidsForChild( + facesWithEmbeddings: List>, + dateOfBirth: Long + ): List { + + // Group faces by age (in years) + val facesByAge = facesWithEmbeddings.groupBy { (_, capturedAt, _) -> + val ageMs = capturedAt - dateOfBirth + val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt() + ageYears.coerceIn(0, 18) // Cap at 18 years + } + + // Create one centroid per age group + return facesByAge.map { (age, faces) -> + val embeddings = faces.map { it.third } + val avgEmbedding = averageEmbeddings(embeddings) + val avgTimestamp = faces.map { it.second }.average().toLong() + + // Calculate confidence (how similar faces are to each other) + val confidences = embeddings.map { emb -> + cosineSimilarity(avgEmbedding, emb) + } + val avgConfidence = confidences.average().toFloat() + + TemporalCentroid( + embedding = avgEmbedding.toList(), + effectiveTimestamp = avgTimestamp, + ageAtCapture = age.toFloat(), + photoCount = faces.size, + timeRangeMonths = 12, // 1 year window + avgConfidence = avgConfidence + ) + }.sortedBy { it.ageAtCapture } + } + + /** + * Create single centroid for an adult (stable appearance) + */ + private fun createSingleCentroid( + facesWithEmbeddings: List> + ): List { + + val embeddings = facesWithEmbeddings.map { it.third } + val avgEmbedding = averageEmbeddings(embeddings) + val avgTimestamp = facesWithEmbeddings.map { it.second }.average().toLong() + + val confidences = embeddings.map { emb -> + cosineSimilarity(avgEmbedding, emb) + } + val avgConfidence = confidences.average().toFloat() + + return listOf( + TemporalCentroid( + embedding = avgEmbedding.toList(), + effectiveTimestamp = avgTimestamp, + ageAtCapture = null, + photoCount = facesWithEmbeddings.size, + timeRangeMonths = 24, // 2 year window for adults + avgConfidence = avgConfidence + ) + ) + } + + /** + * Average multiple embeddings into one + */ + private fun averageEmbeddings(embeddings: List): FloatArray { + val size = embeddings.first().size + val avg = FloatArray(size) { 0f } + + embeddings.forEach { embedding -> + for (i in embedding.indices) { + avg[i] += embedding[i] + } + } + + val count = embeddings.size.toFloat() + for (i in avg.indices) { + avg[i] /= count + } + + // Normalize to unit length + val norm = kotlin.math.sqrt(avg.map { it * it }.sum()) + return avg.map { it / norm }.toFloatArray() + } + + /** + * Calculate cosine similarity between two embeddings + */ + private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { + var dotProduct = 0f + var normA = 0f + var normB = 0f + + for (i in a.indices) { + dotProduct += a[i] * b[i] + normA += a[i] * a[i] + normB += b[i] * b[i] + } + + return dotProduct / (kotlin.math.sqrt(normA) * kotlin.math.sqrt(normB)) + } + + fun cleanup() { + faceNetModel.close() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt b/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt index b8697f5..ff5fabd 100644 --- a/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt +++ b/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt @@ -1,7 +1,13 @@ package com.placeholder.sherpai2.domain.usecase import android.content.Context +import android.graphics.Bitmap +import android.util.Log +import com.google.mlkit.vision.face.Face +import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity +import com.placeholder.sherpai2.data.local.entity.ImageEntity import dagger.hilt.android.qualifiers.ApplicationContext import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async @@ -15,41 +21,56 @@ import kotlinx.coroutines.withContext import java.util.concurrent.atomic.AtomicInteger import javax.inject.Inject import javax.inject.Singleton +import kotlin.math.abs /** - * PopulateFaceDetectionCache - HYPER-PARALLEL face scanning + * PopulateFaceDetectionCache - ENHANCED VERSION * - * STRATEGY: Use ACCURATE mode BUT with MASSIVE parallelization - * - 50 concurrent detections (not 10!) - * - Semaphore limits to prevent OOM - * - Atomic counters for thread-safe progress - * - Smaller images (768px) for speed without quality loss + * NOW POPULATES TWO CACHES: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * 1. ImageEntity cache (hasFaces, faceCount) - for quick filters + * 2. FaceCacheEntity table - for Discovery pre-filtering * - * RESULT: ~2000-3000 images/minute on modern phones + * SAME ML KIT SCAN - Just saves more data! + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Previously: One scan → saves 2 fields (hasFaces, faceCount) + * Now: One scan → saves 2 fields + full face metadata! + * + * RESULT: Discovery can skip Path 3 (8 min) and use Path 2 (3 min) */ @Singleton class PopulateFaceDetectionCacheUseCase @Inject constructor( @ApplicationContext private val context: Context, - private val imageDao: ImageDao + private val imageDao: ImageDao, + private val faceCacheDao: FaceCacheDao ) { - // Limit concurrent operations to prevent OOM - private val semaphore = Semaphore(50) // 50 concurrent detections! + companion object { + private const val TAG = "FaceCachePopulation" + private const val SEMAPHORE_PERMITS = 50 + private const val BATCH_SIZE = 100 + } + + private val semaphore = Semaphore(SEMAPHORE_PERMITS) /** - * HYPER-PARALLEL face detection with ACCURATE mode + * ENHANCED: Populates BOTH image cache AND face metadata cache */ suspend fun execute( onProgress: (Int, Int, String?) -> Unit = { _, _, _ -> } ): Int = withContext(Dispatchers.IO) { - // Create detector with ACCURATE mode but optimized settings + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "Enhanced Face Cache Population Started") + Log.d(TAG, "Populating: ImageEntity + FaceCacheEntity") + Log.d(TAG, "════════════════════════════════════════") + val detector = com.google.mlkit.vision.face.FaceDetection.getClient( com.google.mlkit.vision.face.FaceDetectorOptions.Builder() .setPerformanceMode(com.google.mlkit.vision.face.FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) - .setLandmarkMode(com.google.mlkit.vision.face.FaceDetectorOptions.LANDMARK_MODE_NONE) // Don't need landmarks for cache - .setClassificationMode(com.google.mlkit.vision.face.FaceDetectorOptions.CLASSIFICATION_MODE_NONE) // Don't need classification - .setMinFaceSize(0.1f) // Detect smaller faces + .setLandmarkMode(com.google.mlkit.vision.face.FaceDetectorOptions.LANDMARK_MODE_ALL) + .setClassificationMode(com.google.mlkit.vision.face.FaceDetectorOptions.CLASSIFICATION_MODE_NONE) + .setMinFaceSize(0.1f) .build() ) @@ -57,44 +78,34 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( val imagesToScan = imageDao.getImagesNeedingFaceDetection() if (imagesToScan.isEmpty()) { + Log.d(TAG, "No images need scanning") return@withContext 0 } + Log.d(TAG, "Scanning ${imagesToScan.size} images") + val total = imagesToScan.size val scanned = AtomicInteger(0) - val pendingUpdates = mutableListOf() - val updatesMutex = kotlinx.coroutines.sync.Mutex() + val pendingImageUpdates = mutableListOf() + val pendingFaceCacheUpdates = mutableListOf() + val updatesMutex = Mutex() - // Process ALL images in parallel with semaphore control + // Process all images in parallel coroutineScope { val jobs = imagesToScan.map { image -> async(Dispatchers.Default) { semaphore.acquire() try { - // Load bitmap with medium downsampling (768px = good balance) - val bitmap = loadBitmapOptimized(android.net.Uri.parse(image.imageUri)) - - if (bitmap == null) { - return@async CacheUpdate(image.imageId, false, 0, image.imageUri) - } - - // Detect faces - val inputImage = com.google.mlkit.vision.common.InputImage.fromBitmap(bitmap, 0) - val faces = detector.process(inputImage).await() - bitmap.recycle() - - CacheUpdate( - imageId = image.imageId, - hasFaces = faces.isNotEmpty(), - faceCount = faces.size, - imageUri = image.imageUri - ) + processImage(image, detector) } catch (e: Exception) { - CacheUpdate(image.imageId, false, 0, image.imageUri) + Log.w(TAG, "Error processing ${image.imageId}: ${e.message}") + ScanResult( + ImageCacheUpdate(image.imageId, false, 0, image.imageUri), + emptyList() + ) } finally { semaphore.release() - // Update progress val current = scanned.incrementAndGet() if (current % 50 == 0 || current == total) { onProgress(current, total, image.imageUri) @@ -103,27 +114,42 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( } } - // Wait for all to complete and collect results - jobs.awaitAll().forEach { update -> + // Collect results + jobs.awaitAll().forEach { result -> updatesMutex.withLock { - pendingUpdates.add(update) + pendingImageUpdates.add(result.imageCacheUpdate) + pendingFaceCacheUpdates.addAll(result.faceCacheEntries) - // Batch write to DB every 100 updates - if (pendingUpdates.size >= 100) { - flushUpdates(pendingUpdates.toList()) - pendingUpdates.clear() + // Batch write to DB + if (pendingImageUpdates.size >= BATCH_SIZE) { + flushUpdates( + pendingImageUpdates.toList(), + pendingFaceCacheUpdates.toList() + ) + pendingImageUpdates.clear() + pendingFaceCacheUpdates.clear() } } } // Flush remaining updatesMutex.withLock { - if (pendingUpdates.isNotEmpty()) { - flushUpdates(pendingUpdates) + if (pendingImageUpdates.isNotEmpty()) { + flushUpdates(pendingImageUpdates, pendingFaceCacheUpdates) } } } + val totalFacesCached = withContext(Dispatchers.IO) { + faceCacheDao.getCacheStats().totalFaces + } + + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "Cache Population Complete!") + Log.d(TAG, "Images scanned: ${scanned.get()}") + Log.d(TAG, "Faces cached: $totalFacesCached") + Log.d(TAG, "════════════════════════════════════════") + scanned.get() } finally { detector.close() @@ -131,11 +157,94 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( } /** - * Optimized bitmap loading with configurable max dimension + * Process a single image - detect faces and create cache entries */ - private fun loadBitmapOptimized(uri: android.net.Uri, maxDim: Int = 768): android.graphics.Bitmap? { + private suspend fun processImage( + image: ImageEntity, + detector: com.google.mlkit.vision.face.FaceDetector + ): ScanResult { + val bitmap = loadBitmapOptimized(android.net.Uri.parse(image.imageUri)) + ?: return ScanResult( + ImageCacheUpdate(image.imageId, false, 0, image.imageUri), + emptyList() + ) + + try { + val inputImage = com.google.mlkit.vision.common.InputImage.fromBitmap(bitmap, 0) + val faces = detector.process(inputImage).await() + + val imageWidth = bitmap.width + val imageHeight = bitmap.height + + // Create ImageEntity cache update + val imageCacheUpdate = ImageCacheUpdate( + imageId = image.imageId, + hasFaces = faces.isNotEmpty(), + faceCount = faces.size, + imageUri = image.imageUri + ) + + // Create FaceCacheEntity entries for each face + val faceCacheEntries = faces.mapIndexed { index, face -> + createFaceCacheEntry( + imageId = image.imageId, + faceIndex = index, + face = face, + imageWidth = imageWidth, + imageHeight = imageHeight + ) + } + + return ScanResult(imageCacheUpdate, faceCacheEntries) + + } finally { + bitmap.recycle() + } + } + + /** + * Create FaceCacheEntity from ML Kit Face + * + * Uses FaceCacheEntity.create() which calculates quality metrics automatically + */ + private fun createFaceCacheEntry( + imageId: String, + faceIndex: Int, + face: Face, + imageWidth: Int, + imageHeight: Int + ): FaceCacheEntity { + // Determine if frontal based on head rotation + val isFrontal = isFrontalFace(face) + + return FaceCacheEntity.create( + imageId = imageId, + faceIndex = faceIndex, + boundingBox = face.boundingBox, + imageWidth = imageWidth, + imageHeight = imageHeight, + confidence = 0.9f, // High confidence from accurate detector + isFrontal = isFrontal, + embedding = null // Will be generated later during Discovery + ) + } + + /** + * Check if face is frontal + */ + private fun isFrontalFace(face: Face): Boolean { + val eulerY = face.headEulerAngleY + val eulerZ = face.headEulerAngleZ + + // Frontal if head rotation is within 20 degrees + return abs(eulerY) <= 20f && abs(eulerZ) <= 20f + } + + /** + * Optimized bitmap loading + */ + private fun loadBitmapOptimized(uri: android.net.Uri, maxDim: Int = 768): Bitmap? { return try { - // Get dimensions val options = android.graphics.BitmapFactory.Options().apply { inJustDecodeBounds = true } @@ -143,40 +252,54 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( android.graphics.BitmapFactory.decodeStream(stream, null, options) } - // Calculate sample size var sampleSize = 1 while (options.outWidth / sampleSize > maxDim || options.outHeight / sampleSize > maxDim) { sampleSize *= 2 } - // Load with sample size val finalOptions = android.graphics.BitmapFactory.Options().apply { inSampleSize = sampleSize - inPreferredConfig = android.graphics.Bitmap.Config.ARGB_8888 // Better quality + inPreferredConfig = android.graphics.Bitmap.Config.ARGB_8888 } context.contentResolver.openInputStream(uri)?.use { stream -> android.graphics.BitmapFactory.decodeStream(stream, null, finalOptions) } } catch (e: Exception) { + Log.w(TAG, "Failed to load bitmap: ${e.message}") null } } /** - * Batch DB update + * Batch update both caches */ - private suspend fun flushUpdates(updates: List) = withContext(Dispatchers.IO) { - updates.forEach { update -> + private suspend fun flushUpdates( + imageUpdates: List, + faceUpdates: List + ) = withContext(Dispatchers.IO) { + // Update ImageEntity cache + imageUpdates.forEach { update -> try { imageDao.updateFaceDetectionCache( imageId = update.imageId, hasFaces = update.hasFaces, - faceCount = update.faceCount + faceCount = update.faceCount, + timestamp = System.currentTimeMillis(), + version = ImageEntity.CURRENT_FACE_DETECTION_VERSION ) } catch (e: Exception) { - // Skip failed updates + Log.w(TAG, "Failed to update image cache: ${e.message}") + } + } + + // Insert FaceCacheEntity entries + if (faceUpdates.isNotEmpty()) { + try { + faceCacheDao.insertAll(faceUpdates) + } catch (e: Exception) { + Log.e(TAG, "Failed to insert face cache entries: ${e.message}") } } } @@ -186,36 +309,53 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( } suspend fun getCacheStats(): CacheStats = withContext(Dispatchers.IO) { - val stats = imageDao.getFaceCacheStats() + val imageStats = imageDao.getFaceCacheStats() + val faceStats = faceCacheDao.getCacheStats() + CacheStats( - totalImages = stats?.totalImages ?: 0, - imagesWithFaceCache = stats?.imagesWithFaceCache ?: 0, - imagesWithFaces = stats?.imagesWithFaces ?: 0, - imagesWithoutFaces = stats?.imagesWithoutFaces ?: 0, - needsScanning = stats?.needsScanning ?: 0 + totalImages = imageStats?.totalImages ?: 0, + imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0, + imagesWithFaces = imageStats?.imagesWithFaces ?: 0, + imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0, + needsScanning = imageStats?.needsScanning ?: 0, + totalFacesCached = faceStats.totalFaces, + facesWithEmbeddings = faceStats.withEmbeddings, + averageQuality = faceStats.avgQuality ) } } -private data class CacheUpdate( +/** + * Result of scanning a single image + */ +private data class ScanResult( + val imageCacheUpdate: ImageCacheUpdate, + val faceCacheEntries: List +) + +/** + * Image cache update data + */ +private data class ImageCacheUpdate( val imageId: String, val hasFaces: Boolean, val faceCount: Int, val imageUri: String ) +/** + * Enhanced cache stats + */ data class CacheStats( val totalImages: Int, val imagesWithFaceCache: Int, val imagesWithFaces: Int, val imagesWithoutFaces: Int, - val needsScanning: Int + val needsScanning: Int, + val totalFacesCached: Int, + val facesWithEmbeddings: Int, + val averageQuality: Float ) { - val cacheProgress: Float - get() = if (totalImages > 0) { - imagesWithFaceCache.toFloat() / totalImages.toFloat() - } else 0f - val isComplete: Boolean get() = needsScanning == 0 } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/validation/Validationscanservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/validation/Validationscanservice.kt new file mode 100644 index 0000000..8ed92a0 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/domain/validation/Validationscanservice.kt @@ -0,0 +1,312 @@ +package com.placeholder.sherpai2.domain.validation + +import android.content.Context +import android.graphics.BitmapFactory +import android.net.Uri +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.placeholder.sherpai2.data.local.dao.FaceModelDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.entity.FaceModelEntity +import com.placeholder.sherpai2.data.local.entity.ImageEntity +import com.placeholder.sherpai2.ml.FaceNetModel +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.tasks.await +import kotlinx.coroutines.withContext +import javax.inject.Inject +import javax.inject.Singleton + +/** + * ValidationScanService - Quick validation scan after training + * + * PURPOSE: Let user verify model quality BEFORE full library scan + * + * STRATEGY: + * 1. Sample 20-30 random photos with faces + * 2. Scan for the newly trained person + * 3. Return preview results with confidence scores + * 4. User reviews and decides: "Looks good" or "Add more photos" + * + * THRESHOLD STRATEGY: + * - Use CONSERVATIVE threshold (0.75) for validation + * - Better to show false negatives than false positives + * - If user approves, full scan uses slightly looser threshold (0.70) + */ +@Singleton +class ValidationScanService @Inject constructor( + @ApplicationContext private val context: Context, + private val imageDao: ImageDao, + private val faceModelDao: FaceModelDao +) { + + companion object { + private const val VALIDATION_SAMPLE_SIZE = 25 + private const val VALIDATION_THRESHOLD = 0.75f // Conservative + } + + /** + * Perform validation scan after training + * + * @param personId The newly trained person + * @param onProgress Callback (current, total) + * @return Validation results with preview matches + */ + suspend fun performValidationScan( + personId: String, + onProgress: (Int, Int) -> Unit = { _, _ -> } + ): ValidationScanResult = withContext(Dispatchers.Default) { + + onProgress(0, 100) + + // Step 1: Get face model + val faceModel = withContext(Dispatchers.IO) { + faceModelDao.getFaceModelByPersonId(personId) + } ?: return@withContext ValidationScanResult( + personId = personId, + matches = emptyList(), + sampleSize = 0, + errorMessage = "Face model not found" + ) + + onProgress(10, 100) + + // Step 2: Get random sample of photos with faces + val allPhotosWithFaces = withContext(Dispatchers.IO) { + imageDao.getImagesWithFaces() + } + + if (allPhotosWithFaces.isEmpty()) { + return@withContext ValidationScanResult( + personId = personId, + matches = emptyList(), + sampleSize = 0, + errorMessage = "No photos with faces in library" + ) + } + + // Random sample + val samplePhotos = allPhotosWithFaces.shuffled().take(VALIDATION_SAMPLE_SIZE) + onProgress(20, 100) + + // Step 3: Scan sample photos + val faceNetModel = FaceNetModel(context) + val detector = FaceDetection.getClient( + FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setMinFaceSize(0.15f) + .build() + ) + + try { + val matches = scanPhotosForPerson( + photos = samplePhotos, + faceModel = faceModel, + faceNetModel = faceNetModel, + detector = detector, + threshold = VALIDATION_THRESHOLD, + onProgress = { current, total -> + // Map to 20-100 range + val progress = 20 + (current * 80 / total) + onProgress(progress, 100) + } + ) + + onProgress(100, 100) + + ValidationScanResult( + personId = personId, + matches = matches, + sampleSize = samplePhotos.size, + threshold = VALIDATION_THRESHOLD + ) + + } finally { + faceNetModel.close() + detector.close() + } + } + + /** + * Scan photos for a specific person + */ + private suspend fun scanPhotosForPerson( + photos: List, + faceModel: FaceModelEntity, + faceNetModel: FaceNetModel, + detector: com.google.mlkit.vision.face.FaceDetector, + threshold: Float, + onProgress: (Int, Int) -> Unit + ): List = coroutineScope { + + val modelEmbedding = faceModel.getEmbeddingArray() + val matches = mutableListOf() + var processedCount = 0 + + // Process in parallel + val jobs = photos.map { photo -> + async(Dispatchers.IO) { + val photoMatches = scanSinglePhoto( + photo = photo, + modelEmbedding = modelEmbedding, + faceNetModel = faceNetModel, + detector = detector, + threshold = threshold + ) + + synchronized(matches) { + matches.addAll(photoMatches) + processedCount++ + if (processedCount % 5 == 0) { + onProgress(processedCount, photos.size) + } + } + } + } + + jobs.awaitAll() + matches.sortedByDescending { it.confidence } + } + + /** + * Scan a single photo for the person + */ + private suspend fun scanSinglePhoto( + photo: ImageEntity, + modelEmbedding: FloatArray, + faceNetModel: FaceNetModel, + detector: com.google.mlkit.vision.face.FaceDetector, + threshold: Float + ): List = withContext(Dispatchers.IO) { + + try { + // Load bitmap + val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768) + ?: return@withContext emptyList() + + // Detect faces + val inputImage = InputImage.fromBitmap(bitmap, 0) + val faces = detector.process(inputImage).await() + + // Check each face + val matches = faces.mapNotNull { face -> + try { + // Crop face + val faceBitmap = android.graphics.Bitmap.createBitmap( + bitmap, + face.boundingBox.left.coerceIn(0, bitmap.width - 1), + face.boundingBox.top.coerceIn(0, bitmap.height - 1), + face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left), + face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top) + ) + + // Generate embedding + val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) + faceBitmap.recycle() + + // Calculate similarity + val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) + + if (similarity >= threshold) { + ValidationMatch( + imageId = photo.imageId, + imageUri = photo.imageUri, + capturedAt = photo.capturedAt, + confidence = similarity, + boundingBox = face.boundingBox, + faceCount = faces.size + ) + } else { + null + } + } catch (e: Exception) { + null + } + } + + bitmap.recycle() + matches + + } catch (e: Exception) { + emptyList() + } + } + + /** + * Load bitmap with downsampling + */ + private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? { + return try { + val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true } + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, opts) + } + + var sample = 1 + while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) { + sample *= 2 + } + + val finalOpts = BitmapFactory.Options().apply { + inSampleSize = sample + } + + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, finalOpts) + } + } catch (e: Exception) { + null + } + } +} + +/** + * Result of validation scan + */ +data class ValidationScanResult( + val personId: String, + val matches: List, + val sampleSize: Int, + val threshold: Float = 0.75f, + val errorMessage: String? = null +) { + val matchCount: Int get() = matches.size + val averageConfidence: Float get() = if (matches.isNotEmpty()) { + matches.map { it.confidence }.average().toFloat() + } else 0f + + val qualityAssessment: ValidationQuality get() = when { + matchCount == 0 -> ValidationQuality.NO_MATCHES + averageConfidence >= 0.85f && matchCount >= 5 -> ValidationQuality.EXCELLENT + averageConfidence >= 0.78f && matchCount >= 3 -> ValidationQuality.GOOD + averageConfidence < 0.75f || matchCount < 2 -> ValidationQuality.POOR + else -> ValidationQuality.FAIR + } +} + +/** + * Single match found during validation + */ +data class ValidationMatch( + val imageId: String, + val imageUri: String, + val capturedAt: Long, + val confidence: Float, + val boundingBox: android.graphics.Rect, + val faceCount: Int +) + +/** + * Overall quality assessment + */ +enum class ValidationQuality { + EXCELLENT, // High confidence, many matches + GOOD, // Decent confidence, some matches + FAIR, // Acceptable, proceed with caution + POOR, // Low confidence or very few matches + NO_MATCHES // No matches found at all +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt index 22daad1..a15197d 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt @@ -2,6 +2,7 @@ package com.placeholder.sherpai2.ml import android.content.Context import android.graphics.Bitmap +import android.util.Log import org.tensorflow.lite.Interpreter import java.io.FileInputStream import java.nio.ByteBuffer @@ -11,16 +12,21 @@ import java.nio.channels.FileChannel import kotlin.math.sqrt /** - * FaceNetModel - MobileFaceNet wrapper for face recognition + * FaceNetModel - MobileFaceNet wrapper with debugging * - * CLEAN IMPLEMENTATION: - * - All IDs are Strings (matching your schema) - * - Generates 192-dimensional embeddings - * - Cosine similarity for matching + * IMPROVEMENTS: + * - ✅ Detailed error logging + * - ✅ Model validation on init + * - ✅ Embedding validation (detect all-zeros) + * - ✅ Toggle-able debug mode */ -class FaceNetModel(private val context: Context) { +class FaceNetModel( + private val context: Context, + private val debugMode: Boolean = true // Enable for troubleshooting +) { companion object { + private const val TAG = "FaceNetModel" private const val MODEL_FILE = "mobilefacenet.tflite" private const val INPUT_SIZE = 112 private const val EMBEDDING_SIZE = 192 @@ -31,13 +37,56 @@ class FaceNetModel(private val context: Context) { } private var interpreter: Interpreter? = null + private var modelLoadSuccess = false init { try { + if (debugMode) Log.d(TAG, "Loading FaceNet model: $MODEL_FILE") + val model = loadModelFile() interpreter = Interpreter(model) + modelLoadSuccess = true + + if (debugMode) { + Log.d(TAG, "✅ FaceNet model loaded successfully") + Log.d(TAG, "Model input size: ${INPUT_SIZE}x$INPUT_SIZE") + Log.d(TAG, "Embedding size: $EMBEDDING_SIZE") + } + + // Test model with dummy input + testModel() + } catch (e: Exception) { - throw RuntimeException("Failed to load FaceNet model", e) + Log.e(TAG, "❌ CRITICAL: Failed to load FaceNet model from assets/$MODEL_FILE", e) + Log.e(TAG, "Make sure mobilefacenet.tflite exists in app/src/main/assets/") + modelLoadSuccess = false + throw RuntimeException("Failed to load FaceNet model: ${e.message}", e) + } + } + + /** + * Test model with dummy input to verify it works + */ + private fun testModel() { + try { + val testBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888) + val testEmbedding = generateEmbedding(testBitmap) + testBitmap.recycle() + + val sum = testEmbedding.sum() + val norm = sqrt(testEmbedding.map { it * it }.sum()) + + if (debugMode) { + Log.d(TAG, "Model test: embedding sum=$sum, norm=$norm") + } + + if (sum == 0f || norm == 0f) { + Log.e(TAG, "⚠️ WARNING: Model test produced zero embedding!") + } else { + if (debugMode) Log.d(TAG, "✅ Model test passed") + } + } catch (e: Exception) { + Log.e(TAG, "Model test failed", e) } } @@ -45,12 +94,22 @@ class FaceNetModel(private val context: Context) { * Load TFLite model from assets */ private fun loadModelFile(): MappedByteBuffer { - val fileDescriptor = context.assets.openFd(MODEL_FILE) - val inputStream = FileInputStream(fileDescriptor.fileDescriptor) - val fileChannel = inputStream.channel - val startOffset = fileDescriptor.startOffset - val declaredLength = fileDescriptor.declaredLength - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) + try { + val fileDescriptor = context.assets.openFd(MODEL_FILE) + val inputStream = FileInputStream(fileDescriptor.fileDescriptor) + val fileChannel = inputStream.channel + val startOffset = fileDescriptor.startOffset + val declaredLength = fileDescriptor.declaredLength + + if (debugMode) { + Log.d(TAG, "Model file size: ${declaredLength / 1024}KB") + } + + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) + } catch (e: Exception) { + Log.e(TAG, "Failed to open model file: $MODEL_FILE", e) + throw e + } } /** @@ -60,13 +119,39 @@ class FaceNetModel(private val context: Context) { * @return 192-dimensional embedding */ fun generateEmbedding(faceBitmap: Bitmap): FloatArray { - val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true) - val inputBuffer = preprocessImage(resized) - val output = Array(1) { FloatArray(EMBEDDING_SIZE) } + if (!modelLoadSuccess || interpreter == null) { + Log.e(TAG, "❌ Cannot generate embedding: model not loaded!") + return FloatArray(EMBEDDING_SIZE) { 0f } + } - interpreter?.run(inputBuffer, output) + try { + val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true) + val inputBuffer = preprocessImage(resized) + val output = Array(1) { FloatArray(EMBEDDING_SIZE) } - return normalizeEmbedding(output[0]) + interpreter?.run(inputBuffer, output) + + val normalized = normalizeEmbedding(output[0]) + + // DIAGNOSTIC: Check embedding quality + if (debugMode) { + val sum = normalized.sum() + val norm = sqrt(normalized.map { it * it }.sum()) + + if (sum == 0f && norm == 0f) { + Log.e(TAG, "❌ CRITICAL: Generated all-zero embedding!") + Log.e(TAG, "Input bitmap: ${faceBitmap.width}x${faceBitmap.height}") + } else { + Log.d(TAG, "✅ Embedding: sum=${"%.2f".format(sum)}, norm=${"%.2f".format(norm)}, first5=[${normalized.take(5).joinToString { "%.3f".format(it) }}]") + } + } + + return normalized + + } catch (e: Exception) { + Log.e(TAG, "Failed to generate embedding", e) + return FloatArray(EMBEDDING_SIZE) { 0f } + } } /** @@ -76,6 +161,10 @@ class FaceNetModel(private val context: Context) { faceBitmaps: List, onProgress: (Int, Int) -> Unit = { _, _ -> } ): List { + if (debugMode) { + Log.d(TAG, "Generating embeddings for ${faceBitmaps.size} faces") + } + return faceBitmaps.mapIndexed { index, bitmap -> onProgress(index + 1, faceBitmaps.size) generateEmbedding(bitmap) @@ -88,6 +177,10 @@ class FaceNetModel(private val context: Context) { fun createPersonModel(embeddings: List): FloatArray { require(embeddings.isNotEmpty()) { "Need at least one embedding" } + if (debugMode) { + Log.d(TAG, "Creating person model from ${embeddings.size} embeddings") + } + val averaged = FloatArray(EMBEDDING_SIZE) { 0f } embeddings.forEach { embedding -> @@ -101,7 +194,14 @@ class FaceNetModel(private val context: Context) { averaged[i] /= count } - return normalizeEmbedding(averaged) + val normalized = normalizeEmbedding(averaged) + + if (debugMode) { + val sum = normalized.sum() + Log.d(TAG, "Person model created: sum=${"%.2f".format(sum)}") + } + + return normalized } /** @@ -110,7 +210,7 @@ class FaceNetModel(private val context: Context) { */ fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float { require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) { - "Invalid embedding size" + "Invalid embedding size: ${embedding1.size} vs ${embedding2.size}" } var dotProduct = 0f @@ -123,7 +223,14 @@ class FaceNetModel(private val context: Context) { norm2 += embedding2[i] * embedding2[i] } - return dotProduct / (sqrt(norm1) * sqrt(norm2)) + val similarity = dotProduct / (sqrt(norm1) * sqrt(norm2)) + + if (debugMode && (similarity.isNaN() || similarity.isInfinite())) { + Log.e(TAG, "❌ Invalid similarity: $similarity (norm1=$norm1, norm2=$norm2)") + return 0f + } + + return similarity } /** @@ -151,6 +258,10 @@ class FaceNetModel(private val context: Context) { } } + if (debugMode && bestMatch != null) { + Log.d(TAG, "Best match: ${bestMatch.first} with similarity ${bestMatch.second}") + } + return bestMatch } @@ -169,6 +280,7 @@ class FaceNetModel(private val context: Context) { val g = ((pixel shr 8) and 0xFF) / 255.0f val b = (pixel and 0xFF) / 255.0f + // Normalize to [-1, 1] buffer.putFloat((r - 0.5f) / 0.5f) buffer.putFloat((g - 0.5f) / 0.5f) buffer.putFloat((b - 0.5f) / 0.5f) @@ -190,14 +302,29 @@ class FaceNetModel(private val context: Context) { return if (norm > 0) { FloatArray(embedding.size) { i -> embedding[i] / norm } } else { + Log.w(TAG, "⚠️ Cannot normalize zero embedding") embedding } } + /** + * Get model status for diagnostics + */ + fun getModelStatus(): String { + return if (modelLoadSuccess) { + "✅ Model loaded and operational" + } else { + "❌ Model failed to load - check assets/$MODEL_FILE" + } + } + /** * Clean up resources */ fun close() { + if (debugMode) { + Log.d(TAG, "Closing FaceNet model") + } interpreter?.close() interpreter = null } diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt new file mode 100644 index 0000000..965621d --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt @@ -0,0 +1,297 @@ +package com.placeholder.sherpai2.ui.discover + +import android.net.Uri +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.grid.GridCells +import androidx.compose.foundation.lazy.grid.LazyVerticalGrid +import androidx.compose.foundation.lazy.grid.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Check +import androidx.compose.material.icons.filled.Warning +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import coil.compose.AsyncImage +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier +import com.placeholder.sherpai2.domain.clustering.ClusteringResult +import com.placeholder.sherpai2.domain.clustering.FaceCluster + +/** + * ClusterGridScreen - Shows all discovered clusters in 2x2 grid + * + * Each cluster card shows: + * - 2x2 grid of representative faces + * - Photo count + * - Quality badge (Excellent/Good/Poor) + * - Tap to name + * + * IMPROVEMENTS: + * - ✅ Quality badges for each cluster + * - ✅ Visual indicators for trainable vs non-trainable clusters + * - ✅ Better UX with disabled states for poor quality clusters + */ +@Composable +fun ClusterGridScreen( + result: ClusteringResult, + onSelectCluster: (FaceCluster) -> Unit, + modifier: Modifier = Modifier, + qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() } +) { + Column( + modifier = modifier + .fillMaxSize() + .padding(16.dp) + ) { + // Header + Text( + text = "Found ${result.clusters.size} ${if (result.clusters.size == 1) "Person" else "People"}", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "Tap a cluster to name the person", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Grid of clusters + LazyVerticalGrid( + columns = GridCells.Fixed(2), + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalArrangement = Arrangement.spacedBy(12.dp) + ) { + items(result.clusters) { cluster -> + // Analyze quality for each cluster + val qualityResult = remember(cluster) { + qualityAnalyzer.analyzeCluster(cluster) + } + + ClusterCard( + cluster = cluster, + qualityTier = qualityResult.qualityTier, + canTrain = qualityResult.canTrain, + onClick = { onSelectCluster(cluster) } + ) + } + } + } +} + +/** + * Single cluster card with 2x2 face grid and quality badge + */ +@Composable +private fun ClusterCard( + cluster: FaceCluster, + qualityTier: ClusterQualityTier, + canTrain: Boolean, + onClick: () -> Unit +) { + Card( + modifier = Modifier + .fillMaxWidth() + .aspectRatio(1f) + .clickable(onClick = onClick), // Always clickable - let dialog handle validation + elevation = CardDefaults.cardElevation(defaultElevation = 2.dp), + colors = CardDefaults.cardColors( + containerColor = when { + qualityTier == ClusterQualityTier.POOR -> + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + !canTrain -> + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + else -> + MaterialTheme.colorScheme.surface + } + ) + ) { + Box( + modifier = Modifier.fillMaxSize() + ) { + Column( + modifier = Modifier.fillMaxSize() + ) { + // 2x2 grid of faces + val facesToShow = cluster.representativeFaces.take(4) + + Column( + modifier = Modifier.weight(1f) + ) { + // Top row (2 faces) + Row(modifier = Modifier.weight(1f)) { + facesToShow.getOrNull(0)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + + facesToShow.getOrNull(1)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + } + + // Bottom row (2 faces) + Row(modifier = Modifier.weight(1f)) { + facesToShow.getOrNull(2)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + + facesToShow.getOrNull(3)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + } + } + + // Footer with photo count + Surface( + modifier = Modifier.fillMaxWidth(), + color = if (canTrain) { + MaterialTheme.colorScheme.primaryContainer + } else { + MaterialTheme.colorScheme.surfaceVariant + } + ) { + Row( + modifier = Modifier.padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Text( + text = "${cluster.photoCount} photos", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold, + color = if (canTrain) { + MaterialTheme.colorScheme.onPrimaryContainer + } else { + MaterialTheme.colorScheme.onSurfaceVariant + } + ) + } + } + } + + // Quality badge overlay + QualityBadge( + qualityTier = qualityTier, + canTrain = canTrain, + modifier = Modifier + .align(Alignment.TopEnd) + .padding(8.dp) + ) + } + } +} + +/** + * Quality badge indicator + */ +@Composable +private fun QualityBadge( + qualityTier: ClusterQualityTier, + canTrain: Boolean, + modifier: Modifier = Modifier +) { + val (backgroundColor, iconColor, icon) = when (qualityTier) { + ClusterQualityTier.EXCELLENT -> Triple( + Color(0xFF1B5E20), + Color.White, + Icons.Default.Check + ) + ClusterQualityTier.GOOD -> Triple( + Color(0xFF2E7D32), + Color.White, + Icons.Default.Check + ) + ClusterQualityTier.POOR -> Triple( + Color(0xFFD32F2F), + Color.White, + Icons.Default.Warning + ) + } + + Surface( + modifier = modifier, + shape = CircleShape, + color = backgroundColor, + shadowElevation = 2.dp + ) { + Box( + modifier = Modifier + .size(32.dp) + .padding(6.dp), + contentAlignment = Alignment.Center + ) { + Icon( + imageVector = icon, + contentDescription = qualityTier.name, + tint = iconColor, + modifier = Modifier.size(20.dp) + ) + } + } +} + +@Composable +private fun FaceThumbnail( + imageUri: String, + enabled: Boolean, + modifier: Modifier = Modifier +) { + Box(modifier = modifier) { + AsyncImage( + model = Uri.parse(imageUri), + contentDescription = "Face", + modifier = Modifier + .fillMaxSize() + .border( + width = 0.5.dp, + color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) + ), + contentScale = ContentScale.Crop, + alpha = if (enabled) 1f else 0.6f + ) + } +} + +@Composable +private fun EmptyFaceSlot(modifier: Modifier = Modifier) { + Box( + modifier = modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surfaceVariant) + .border( + width = 0.5.dp, + color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) + ) + ) +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt new file mode 100644 index 0000000..4609d00 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt @@ -0,0 +1,753 @@ +package com.placeholder.sherpai2.ui.discover + +import androidx.compose.foundation.layout.* +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Person +import androidx.compose.material.icons.filled.Refresh +import androidx.compose.material.icons.filled.Storage +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import androidx.hilt.navigation.compose.hiltViewModel +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer + +/** + * DiscoverPeopleScreen - WITH SETTINGS SUPPORT + * + * NEW FEATURES: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * ✅ Discovery settings card with quality sliders + * ✅ Retry button in naming dialog + * ✅ Cache building progress UI + * ✅ Settings affect clustering behavior + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun DiscoverPeopleScreen( + viewModel: DiscoverPeopleViewModel = hiltViewModel(), + onNavigateBack: () -> Unit = {} +) { + val uiState by viewModel.uiState.collectAsState() + val qualityAnalyzer = remember { ClusterQualityAnalyzer() } + + // NEW: Settings state + var settings by remember { mutableStateOf(DiscoverySettings.DEFAULT) } + + Box(modifier = Modifier.fillMaxSize()) { + when (val state = uiState) { + // ===== IDLE STATE (START HERE) ===== + is DiscoverUiState.Idle -> { + IdleStateWithSettings( + settings = settings, + onSettingsChange = { settings = it }, + onStartDiscovery = { viewModel.startDiscovery(settings) } + ) + } + + // ===== NEW: BUILDING CACHE (FIRST-TIME SETUP) ===== + is DiscoverUiState.BuildingCache -> { + BuildingCacheContent( + progress = state.progress, + total = state.total, + message = state.message + ) + } + + // ===== CLUSTERING IN PROGRESS ===== + is DiscoverUiState.Clustering -> { + ClusteringProgressContent( + progress = state.progress, + total = state.total, + message = state.message + ) + } + + // ===== CLUSTERS READY FOR NAMING ===== + is DiscoverUiState.NamingReady -> { + ClusterGridScreen( + result = state.result, + onSelectCluster = { cluster -> + viewModel.selectCluster(cluster) + }, + qualityAnalyzer = qualityAnalyzer + ) + } + + // ===== ANALYZING CLUSTER QUALITY ===== + is DiscoverUiState.AnalyzingCluster -> { + LoadingContent(message = "Analyzing cluster quality...") + } + + // ===== NAMING A CLUSTER (SHOW DIALOG) ===== + is DiscoverUiState.NamingCluster -> { + ClusterGridScreen( + result = state.result, + onSelectCluster = { /* Disabled while dialog open */ }, + qualityAnalyzer = qualityAnalyzer + ) + + NamingDialog( + cluster = state.selectedCluster, + suggestedSiblings = state.suggestedSiblings, + onConfirm = { name, dateOfBirth, isChild, selectedSiblings -> + viewModel.confirmClusterName( + cluster = state.selectedCluster, + name = name, + dateOfBirth = dateOfBirth, + isChild = isChild, + selectedSiblings = selectedSiblings + ) + }, + onRetry = { viewModel.retryDiscovery() }, // NEW! + onDismiss = { + viewModel.cancelNaming() + }, + qualityAnalyzer = qualityAnalyzer + ) + } + + // ===== TRAINING IN PROGRESS ===== + is DiscoverUiState.Training -> { + TrainingProgressContent( + stage = state.stage, + progress = state.progress, + total = state.total + ) + } + + // ===== VALIDATION PREVIEW ===== + is DiscoverUiState.ValidationPreview -> { + ValidationPreviewScreen( + personName = state.personName, + validationResult = state.validationResult, + onMarkFeedback = { feedbackMap -> + viewModel.submitFeedback(state.cluster, feedbackMap) + }, + onRequestRefinement = { + viewModel.requestRefinement(state.cluster) + }, + onApprove = { + viewModel.acceptValidationAndFinish() + }, + onReject = { + viewModel.requestRefinement(state.cluster) + } + ) + } + + // ===== REFINEMENT NEEDED ===== + is DiscoverUiState.RefinementNeeded -> { + RefinementNeededContent( + recommendation = state.recommendation, + currentIteration = state.currentIteration, + onRefine = { + viewModel.requestRefinement(state.cluster) + }, + onSkip = { + viewModel.skipRefinement() + } + ) + } + + // ===== REFINING IN PROGRESS ===== + is DiscoverUiState.Refining -> { + RefiningProgressContent( + iteration = state.iteration, + message = state.message + ) + } + + // ===== COMPLETE ===== + is DiscoverUiState.Complete -> { + CompleteStateContent( + message = state.message, + onDone = onNavigateBack, + onDiscoverMore = { viewModel.retryDiscovery() } + ) + } + + // ===== NO PEOPLE FOUND ===== + is DiscoverUiState.NoPeopleFound -> { + ErrorStateContent( + title = "No People Found", + message = state.message, + onRetry = { viewModel.retryDiscovery() }, + onBack = onNavigateBack + ) + } + + // ===== ERROR ===== + is DiscoverUiState.Error -> { + ErrorStateContent( + title = "Error", + message = state.message, + onRetry = { viewModel.retryDiscovery() }, + onBack = onNavigateBack + ) + } + } + } +} + +// ═══════════════════════════════════════════════════════════ +// IDLE STATE WITH SETTINGS +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun IdleStateWithSettings( + settings: DiscoverySettings, + onSettingsChange: (DiscoverySettings) -> Unit, + onStartDiscovery: () -> Unit +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Icon( + imageVector = Icons.Default.Person, + contentDescription = null, + modifier = Modifier.size(120.dp), + tint = MaterialTheme.colorScheme.primary + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Text( + text = "Automatically find and organize people in your photo library", + style = MaterialTheme.typography.headlineSmall, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurface + ) + + Spacer(modifier = Modifier.height(32.dp)) + + // NEW: Settings Card + DiscoverySettingsCard( + settings = settings, + onSettingsChange = onSettingsChange + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Button( + onClick = onStartDiscovery, + modifier = Modifier + .fillMaxWidth() + .height(56.dp) + ) { + Text( + text = "Start Discovery", + style = MaterialTheme.typography.titleMedium + ) + } + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = "This will analyze faces in your photos and group similar faces together", + style = MaterialTheme.typography.bodySmall, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} + +// ═══════════════════════════════════════════════════════════ +// BUILDING CACHE CONTENT +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun BuildingCacheContent( + progress: Int, + total: Int, + message: String +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Icon( + imageVector = Icons.Default.Storage, + contentDescription = null, + modifier = Modifier.size(80.dp), + tint = MaterialTheme.colorScheme.primary + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Text( + text = "Building Cache", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + textAlign = TextAlign.Center + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ), + modifier = Modifier.fillMaxWidth() + ) { + Column( + modifier = Modifier.padding(16.dp), + horizontalAlignment = Alignment.CenterHorizontally + ) { + Text( + text = message, + style = MaterialTheme.typography.bodyMedium, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + } + } + + Spacer(modifier = Modifier.height(24.dp)) + + if (total > 0) { + LinearProgressIndicator( + progress = { progress.toFloat() / total.toFloat() }, + modifier = Modifier + .fillMaxWidth() + .height(12.dp) + ) + + Spacer(modifier = Modifier.height(12.dp)) + + Text( + text = "$progress / $total photos analyzed", + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.Medium, + color = MaterialTheme.colorScheme.primary + ) + + Spacer(modifier = Modifier.height(8.dp)) + + val percentComplete = (progress.toFloat() / total.toFloat() * 100).toInt() + Text( + text = "$percentComplete% complete", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } else { + CircularProgressIndicator( + modifier = Modifier.size(64.dp) + ) + } + + Spacer(modifier = Modifier.height(32.dp)) + + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer + ), + modifier = Modifier.fillMaxWidth() + ) { + Column( + modifier = Modifier.padding(16.dp) + ) { + Text( + text = "ℹ️ What's happening?", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "We're analyzing your photo library once to identify which photos contain faces. " + + "This speeds up future discoveries by 95%!\n\n" + + "This only happens once and will make all future discoveries instant.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + } + } + } +} + +// ═══════════════════════════════════════════════════════════ +// CLUSTERING PROGRESS +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun ClusteringProgressContent( + progress: Int, + total: Int, + message: String +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + CircularProgressIndicator( + modifier = Modifier.size(64.dp) + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Text( + text = message, + style = MaterialTheme.typography.titleMedium, + textAlign = TextAlign.Center + ) + + Spacer(modifier = Modifier.height(16.dp)) + + if (total > 0) { + LinearProgressIndicator( + progress = { progress.toFloat() / total.toFloat() }, + modifier = Modifier + .fillMaxWidth() + .height(8.dp) + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "$progress / $total", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } +} + +// ═══════════════════════════════════════════════════════════ +// TRAINING PROGRESS +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun TrainingProgressContent( + stage: String, + progress: Int, + total: Int +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + CircularProgressIndicator( + modifier = Modifier.size(64.dp) + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Text( + text = stage, + style = MaterialTheme.typography.titleMedium, + textAlign = TextAlign.Center + ) + + if (total > 0) { + Spacer(modifier = Modifier.height(16.dp)) + + LinearProgressIndicator( + progress = { progress.toFloat() / total.toFloat() }, + modifier = Modifier + .fillMaxWidth() + .height(8.dp) + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "$progress / $total", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } +} + +// ═══════════════════════════════════════════════════════════ +// REFINEMENT NEEDED +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun RefinementNeededContent( + recommendation: com.placeholder.sherpai2.domain.clustering.RefinementRecommendation, + currentIteration: Int, + onRefine: () -> Unit, + onSkip: () -> Unit +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Icon( + imageVector = Icons.Default.Person, + contentDescription = null, + modifier = Modifier.size(80.dp), + tint = MaterialTheme.colorScheme.primary + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Text( + text = "Refinement Recommended", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.errorContainer + ) + ) { + Column( + modifier = Modifier.padding(16.dp) + ) { + Text( + text = recommendation.reason, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onErrorContainer + ) + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = "Iteration: $currentIteration", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onSkip, + modifier = Modifier.weight(1f) + ) { + Text("Skip") + } + + Button( + onClick = onRefine, + modifier = Modifier.weight(1f) + ) { + Text("Refine Cluster") + } + } + } +} + +// ═══════════════════════════════════════════════════════════ +// REFINING PROGRESS +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun RefiningProgressContent( + iteration: Int, + message: String +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + CircularProgressIndicator( + modifier = Modifier.size(64.dp) + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Text( + text = "Refining Cluster", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = message, + style = MaterialTheme.typography.bodyMedium, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = "Iteration $iteration", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} + +// ═══════════════════════════════════════════════════════════ +// LOADING CONTENT +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun LoadingContent(message: String) { + Column( + modifier = Modifier.fillMaxSize(), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + CircularProgressIndicator() + Spacer(modifier = Modifier.height(16.dp)) + Text(text = message) + } +} + +// ═══════════════════════════════════════════════════════════ +// COMPLETE STATE +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun CompleteStateContent( + message: String, + onDone: () -> Unit, + onDiscoverMore: () -> Unit +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Text( + text = "🎉", + style = MaterialTheme.typography.displayLarge + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Text( + text = "Success!", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = message, + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Button( + onClick = onDone, + modifier = Modifier.fillMaxWidth() + ) { + Text("Done") + } + + Spacer(modifier = Modifier.height(12.dp)) + + OutlinedButton( + onClick = onDiscoverMore, + modifier = Modifier.fillMaxWidth() + ) { + Icon( + imageVector = Icons.Default.Refresh, + contentDescription = null, + modifier = Modifier.size(20.dp) + ) + Spacer(Modifier.width(8.dp)) + Text("Discover More People") + } + } +} + +// ═══════════════════════════════════════════════════════════ +// ERROR STATE +// ═══════════════════════════════════════════════════════════ + +@Composable +private fun ErrorStateContent( + title: String, + message: String, + onRetry: () -> Unit, + onBack: () -> Unit +) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + Text( + text = "⚠️", + style = MaterialTheme.typography.displayLarge + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Text( + text = title, + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = message, + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(32.dp)) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onBack, + modifier = Modifier.weight(1f) + ) { + Text("Back") + } + + Button( + onClick = onRetry, + modifier = Modifier.weight(1f) + ) { + Text("Retry") + } + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeopleviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeopleviewmodel.kt new file mode 100644 index 0000000..95a748c --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeopleviewmodel.kt @@ -0,0 +1,523 @@ +package com.placeholder.sherpai2.ui.discover + +import android.content.Context +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import androidx.work.* +import com.placeholder.sherpai2.data.local.dao.FaceCacheDao +import com.placeholder.sherpai2.data.local.entity.FeedbackType +import com.placeholder.sherpai2.domain.clustering.* +import com.placeholder.sherpai2.domain.training.ClusterTrainingService +import com.placeholder.sherpai2.domain.validation.ValidationScanResult +import com.placeholder.sherpai2.domain.validation.ValidationScanService +import com.placeholder.sherpai2.workers.CachePopulationWorker +import dagger.hilt.android.lifecycle.HiltViewModel +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.launch +import javax.inject.Inject + +@HiltViewModel +class DiscoverPeopleViewModel @Inject constructor( + @ApplicationContext private val context: Context, + private val clusteringService: FaceClusteringService, + private val trainingService: ClusterTrainingService, + private val validationService: ValidationScanService, + private val refinementService: ClusterRefinementService, + private val faceCacheDao: FaceCacheDao +) : ViewModel() { + + private val _uiState = MutableStateFlow(DiscoverUiState.Idle) + val uiState: StateFlow = _uiState.asStateFlow() + + private val namedClusterIds = mutableSetOf() + private var currentIterationCount = 0 + + // NEW: Store settings for use after cache population + private var lastUsedSettings: DiscoverySettings = DiscoverySettings.DEFAULT + + private val workManager = WorkManager.getInstance(context) + private var cacheWorkRequestId: java.util.UUID? = null + + /** + * ENHANCED: Check cache before starting Discovery (with settings support) + */ + fun startDiscovery(settings: DiscoverySettings = DiscoverySettings.DEFAULT) { + lastUsedSettings = settings // Store for later use + + // LOG SETTINGS + android.util.Log.d("DiscoverVM", "═══════════════════════════════════════") + android.util.Log.d("DiscoverVM", "🎛️ DISCOVERY SETTINGS") + android.util.Log.d("DiscoverVM", "═══════════════════════════════════════") + android.util.Log.d("DiscoverVM", "Min Face Size: ${settings.minFaceSize} (${(settings.minFaceSize * 100).toInt()}%)") + android.util.Log.d("DiscoverVM", "Min Quality: ${settings.minQuality} (${(settings.minQuality * 100).toInt()}%)") + android.util.Log.d("DiscoverVM", "Epsilon: ${settings.epsilon}") + android.util.Log.d("DiscoverVM", "Is Default: ${settings == DiscoverySettings.DEFAULT}") + android.util.Log.d("DiscoverVM", "═══════════════════════════════════════") + viewModelScope.launch { + try { + namedClusterIds.clear() + currentIterationCount = 0 + + // Check cache status + val cacheStats = faceCacheDao.getCacheStats() + + android.util.Log.d("DiscoverVM", "Cache check: totalFaces=${cacheStats.totalFaces}") + + if (cacheStats.totalFaces == 0) { + // Cache empty - need to build it first + android.util.Log.d("DiscoverVM", "Cache empty, starting cache population") + + _uiState.value = DiscoverUiState.BuildingCache( + progress = 0, + total = 100, + message = "First-time setup: Building face cache...\n\nThis is a one-time process that will take 5-10 minutes." + ) + + startCachePopulation() + } else { + android.util.Log.d("DiscoverVM", "Cache exists (${cacheStats.totalFaces} faces), proceeding to Discovery") + + // Cache exists - proceed to Discovery + _uiState.value = DiscoverUiState.Clustering(0, 100, "Starting discovery...") + executeDiscovery() + } + } catch (e: Exception) { + android.util.Log.e("DiscoverVM", "Error checking cache", e) + _uiState.value = DiscoverUiState.Error( + "Failed to check cache: ${e.message}" + ) + } + } + } + + /** + * Start cache population worker + */ + private fun startCachePopulation() { + viewModelScope.launch { + android.util.Log.d("DiscoverVM", "Enqueuing CachePopulationWorker") + + val workRequest = OneTimeWorkRequestBuilder() + .setConstraints( + Constraints.Builder() + .setRequiresCharging(false) + .setRequiresBatteryNotLow(false) + .build() + ) + .build() + + cacheWorkRequestId = workRequest.id + + // Enqueue work + workManager.enqueueUniqueWork( + CachePopulationWorker.WORK_NAME, + ExistingWorkPolicy.REPLACE, + workRequest + ) + + // Observe progress + workManager.getWorkInfoByIdLiveData(workRequest.id).observeForever { workInfo -> + android.util.Log.d("DiscoverVM", "Worker state: ${workInfo?.state}") + + when (workInfo?.state) { + WorkInfo.State.RUNNING -> { + val current = workInfo.progress.getInt( + CachePopulationWorker.KEY_PROGRESS_CURRENT, + 0 + ) + val total = workInfo.progress.getInt( + CachePopulationWorker.KEY_PROGRESS_TOTAL, + 100 + ) + + _uiState.value = DiscoverUiState.BuildingCache( + progress = current, + total = total, + message = "Building face cache...\n\nAnalyzing $current of $total photos\n\nThis improves future Discovery performance by 95%!" + ) + } + + WorkInfo.State.SUCCEEDED -> { + val cachedCount = workInfo.outputData.getInt( + CachePopulationWorker.KEY_CACHED_COUNT, + 0 + ) + + android.util.Log.d("DiscoverVM", "Cache population complete: $cachedCount faces") + + _uiState.value = DiscoverUiState.BuildingCache( + progress = 100, + total = 100, + message = "Cache complete! Found $cachedCount faces.\n\nStarting Discovery now..." + ) + + // Automatically start Discovery after cache is ready + viewModelScope.launch { + kotlinx.coroutines.delay(1000) + _uiState.value = DiscoverUiState.Clustering(0, 100, "Starting discovery...") + executeDiscovery() + } + } + + WorkInfo.State.FAILED -> { + val error = workInfo.outputData.getString("error") + android.util.Log.e("DiscoverVM", "Cache population failed: $error") + + _uiState.value = DiscoverUiState.Error( + "Cache building failed: ${error ?: "Unknown error"}\n\n" + + "Discovery will use slower full-scan mode.\n\n" + + "You can retry cache building later." + ) + } + + else -> { + // ENQUEUED, BLOCKED, CANCELLED + } + } + } + } + } + + /** + * Execute the actual Discovery clustering (with settings support) + */ + private suspend fun executeDiscovery() { + try { + // LOG WHICH PATH WE'RE TAKING + android.util.Log.d("DiscoverVM", "═══════════════════════════════════════") + android.util.Log.d("DiscoverVM", "🚀 EXECUTING DISCOVERY") + android.util.Log.d("DiscoverVM", "═══════════════════════════════════════") + + // Use discoverPeopleWithSettings if settings are non-default + val result = if (lastUsedSettings == DiscoverySettings.DEFAULT) { + android.util.Log.d("DiscoverVM", "Using DEFAULT settings path") + android.util.Log.d("DiscoverVM", "Calling: clusteringService.discoverPeople()") + + // Use regular method for default settings + clusteringService.discoverPeople( + strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY, + onProgress = { current: Int, total: Int, message: String -> + _uiState.value = DiscoverUiState.Clustering(current, total, message) + } + ) + } else { + android.util.Log.d("DiscoverVM", "Using CUSTOM settings path") + android.util.Log.d("DiscoverVM", "Settings: minFaceSize=${lastUsedSettings.minFaceSize}, minQuality=${lastUsedSettings.minQuality}, epsilon=${lastUsedSettings.epsilon}") + android.util.Log.d("DiscoverVM", "Calling: clusteringService.discoverPeopleWithSettings()") + + // Use settings-aware method + clusteringService.discoverPeopleWithSettings( + settings = lastUsedSettings, + onProgress = { current: Int, total: Int, message: String -> + _uiState.value = DiscoverUiState.Clustering(current, total, message) + } + ) + } + + android.util.Log.d("DiscoverVM", "Discovery complete: ${result.clusters.size} clusters found") + android.util.Log.d("DiscoverVM", "═══════════════════════════════════════") + + if (result.errorMessage != null) { + _uiState.value = DiscoverUiState.Error(result.errorMessage) + return + } + + if (result.clusters.isEmpty()) { + _uiState.value = DiscoverUiState.NoPeopleFound( + result.errorMessage + ?: "No people clusters found.\n\nTry:\n• Adding more solo photos\n• Ensuring photos are clear\n• Having 6+ photos per person" + ) + } else { + _uiState.value = DiscoverUiState.NamingReady(result) + } + } catch (e: Exception) { + android.util.Log.e("DiscoverVM", "Discovery failed", e) + _uiState.value = DiscoverUiState.Error(e.message ?: "Failed to discover people") + } + } + + fun selectCluster(cluster: FaceCluster) { + val currentState = _uiState.value + if (currentState is DiscoverUiState.NamingReady) { + _uiState.value = DiscoverUiState.NamingCluster( + result = currentState.result, + selectedCluster = cluster, + suggestedSiblings = currentState.result.clusters.filter { + it.clusterId in cluster.potentialSiblings + } + ) + } + } + + fun confirmClusterName( + cluster: FaceCluster, + name: String, + dateOfBirth: Long?, + isChild: Boolean, + selectedSiblings: List + ) { + viewModelScope.launch { + try { + val currentState = _uiState.value + if (currentState !is DiscoverUiState.NamingCluster) return@launch + + _uiState.value = DiscoverUiState.AnalyzingCluster + + _uiState.value = DiscoverUiState.Training( + stage = "Creating face model for $name...", + progress = 0, + total = cluster.faces.size + ) + + val personId = trainingService.trainFromCluster( + cluster = cluster, + name = name, + dateOfBirth = dateOfBirth, + isChild = isChild, + siblingClusterIds = selectedSiblings, + onProgress = { current: Int, total: Int, message: String -> + _uiState.value = DiscoverUiState.Training(message, current, total) + } + ) + + _uiState.value = DiscoverUiState.Training( + stage = "Running validation scan...", + progress = 0, + total = 100 + ) + + val validationResult = validationService.performValidationScan( + personId = personId, + onProgress = { current: Int, total: Int -> + _uiState.value = DiscoverUiState.Training( + stage = "Validating model quality...", + progress = current, + total = total + ) + } + ) + + _uiState.value = DiscoverUiState.ValidationPreview( + personId = personId, + personName = name, + cluster = cluster, + validationResult = validationResult + ) + + } catch (e: Exception) { + _uiState.value = DiscoverUiState.Error(e.message ?: "Failed to create person") + } + } + } + + fun submitFeedback( + cluster: FaceCluster, + feedbackMap: Map + ) { + viewModelScope.launch { + try { + val faceFeedbackMap = cluster.faces + .associateWith { face -> + feedbackMap[face.imageId] ?: FeedbackType.UNCERTAIN + } + + val originalConfidences = cluster.faces.associateWith { it.confidence } + + refinementService.storeFeedback( + cluster = cluster, + feedbackMap = faceFeedbackMap, + originalConfidences = originalConfidences + ) + + val recommendation = refinementService.shouldRefineCluster(cluster) + + if (recommendation.shouldRefine) { + _uiState.value = DiscoverUiState.RefinementNeeded( + cluster = cluster, + recommendation = recommendation, + currentIteration = currentIterationCount + ) + } + } catch (e: Exception) { + _uiState.value = DiscoverUiState.Error( + "Failed to process feedback: ${e.message}" + ) + } + } + } + + fun requestRefinement(cluster: FaceCluster) { + viewModelScope.launch { + try { + currentIterationCount++ + + _uiState.value = DiscoverUiState.Refining( + iteration = currentIterationCount, + message = "Removing incorrect faces and re-clustering..." + ) + + val refinementResult = refinementService.refineCluster( + cluster = cluster, + iterationNumber = currentIterationCount + ) + + if (!refinementResult.success || refinementResult.refinedCluster == null) { + _uiState.value = DiscoverUiState.Error( + refinementResult.errorMessage + ?: "Failed to refine cluster. Please try manual training." + ) + return@launch + } + + val currentState = _uiState.value + if (currentState is DiscoverUiState.RefinementNeeded) { + confirmClusterName( + cluster = refinementResult.refinedCluster, + name = currentState.cluster.representativeFaces.first().imageId, + dateOfBirth = null, + isChild = false, + selectedSiblings = emptyList() + ) + } + + } catch (e: Exception) { + _uiState.value = DiscoverUiState.Error( + "Refinement failed: ${e.message}" + ) + } + } + } + + fun approveValidationAndScan(personId: String, personName: String) { + viewModelScope.launch { + try { + _uiState.value = DiscoverUiState.Complete( + message = "Successfully created model for \"$personName\"!\n\n" + + "Full library scan has been queued in the background.\n\n" + + "✅ ${currentIterationCount} refinement iterations completed" + ) + } catch (e: Exception) { + _uiState.value = DiscoverUiState.Error(e.message ?: "Failed to start library scan") + } + } + } + + fun rejectValidationAndImprove() { + _uiState.value = DiscoverUiState.Error( + "Please add more training photos and try again.\n\n" + + "(Feature coming: ability to add photos to existing model)" + ) + } + + fun cancelNaming() { + val currentState = _uiState.value + if (currentState is DiscoverUiState.NamingCluster) { + _uiState.value = DiscoverUiState.NamingReady(result = currentState.result) + } + } + + fun reset() { + cacheWorkRequestId?.let { workId -> + workManager.cancelWorkById(workId) + } + + _uiState.value = DiscoverUiState.Idle + namedClusterIds.clear() + currentIterationCount = 0 + } + + /** + * Retry discovery (returns to idle state) + */ + fun retryDiscovery() { + _uiState.value = DiscoverUiState.Idle + } + + /** + * Accept validation results and finish + */ + fun acceptValidationAndFinish() { + _uiState.value = DiscoverUiState.Complete( + "Person created successfully!" + ) + } + + /** + * Skip refinement and finish + */ + fun skipRefinement() { + _uiState.value = DiscoverUiState.Complete( + "Person created successfully!" + ) + } +} + +/** + * UI States - ENHANCED with BuildingCache state + */ +sealed class DiscoverUiState { + object Idle : DiscoverUiState() + + data class BuildingCache( + val progress: Int, + val total: Int, + val message: String + ) : DiscoverUiState() + + data class Clustering( + val progress: Int, + val total: Int, + val message: String + ) : DiscoverUiState() + + data class NamingReady( + val result: ClusteringResult + ) : DiscoverUiState() + + data class NamingCluster( + val result: ClusteringResult, + val selectedCluster: FaceCluster, + val suggestedSiblings: List + ) : DiscoverUiState() + + object AnalyzingCluster : DiscoverUiState() + + data class Training( + val stage: String, + val progress: Int, + val total: Int + ) : DiscoverUiState() + + data class ValidationPreview( + val personId: String, + val personName: String, + val cluster: FaceCluster, + val validationResult: ValidationScanResult + ) : DiscoverUiState() + + data class RefinementNeeded( + val cluster: FaceCluster, + val recommendation: RefinementRecommendation, + val currentIteration: Int + ) : DiscoverUiState() + + data class Refining( + val iteration: Int, + val message: String + ) : DiscoverUiState() + + data class Complete( + val message: String + ) : DiscoverUiState() + + data class NoPeopleFound( + val message: String + ) : DiscoverUiState() + + data class Error( + val message: String + ) : DiscoverUiState() +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverysettingscard.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverysettingscard.kt new file mode 100644 index 0000000..21e7175 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverysettingscard.kt @@ -0,0 +1,309 @@ +package com.placeholder.sherpai2.ui.discover + +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.expandVertically +import androidx.compose.animation.shrinkVertically +import androidx.compose.foundation.layout.* +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp + +/** + * DiscoverySettingsCard - Quality control sliders + * + * Allows tuning without dropping quality: + * - Face size threshold (bigger = more strict) + * - Quality score threshold (higher = better faces) + * - Clustering strictness (tighter = more clusters) + */ +@Composable +fun DiscoverySettingsCard( + settings: DiscoverySettings, + onSettingsChange: (DiscoverySettings) -> Unit, + modifier: Modifier = Modifier +) { + var expanded by remember { mutableStateOf(false) } + + Card( + modifier = modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant + ) + ) { + Column( + modifier = Modifier.fillMaxWidth() + ) { + // Header - Always visible + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Tune, + contentDescription = null, + tint = MaterialTheme.colorScheme.primary + ) + Column { + Text( + text = "Quality Settings", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = if (expanded) "Hide settings" else "Tap to adjust", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + IconButton(onClick = { expanded = !expanded }) { + Icon( + imageVector = if (expanded) Icons.Default.ExpandLess + else Icons.Default.ExpandMore, + contentDescription = if (expanded) "Collapse" else "Expand" + ) + } + } + + // Settings - Expandable + AnimatedVisibility( + visible = expanded, + enter = expandVertically(), + exit = shrinkVertically() + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp) + .padding(bottom = 16.dp), + verticalArrangement = Arrangement.spacedBy(20.dp) + ) { + HorizontalDivider() + + // Face Size Slider + QualitySlider( + title = "Minimum Face Size", + description = "Smaller = more faces, larger = higher quality", + currentValue = "${(settings.minFaceSize * 100).toInt()}%", + value = settings.minFaceSize, + onValueChange = { onSettingsChange(settings.copy(minFaceSize = it)) }, + valueRange = 0.02f..0.08f, + icon = Icons.Default.ZoomIn + ) + + // Quality Score Slider + QualitySlider( + title = "Quality Threshold", + description = "Lower = more faces, higher = better quality", + currentValue = "${(settings.minQuality * 100).toInt()}%", + value = settings.minQuality, + onValueChange = { onSettingsChange(settings.copy(minQuality = it)) }, + valueRange = 0.4f..0.8f, + icon = Icons.Default.HighQuality + ) + + // Clustering Strictness + QualitySlider( + title = "Clustering Strictness", + description = when { + settings.epsilon < 0.20f -> "Very strict (more clusters)" + settings.epsilon > 0.25f -> "Loose (fewer clusters)" + else -> "Balanced" + }, + currentValue = when { + settings.epsilon < 0.20f -> "Strict" + settings.epsilon > 0.25f -> "Loose" + else -> "Normal" + }, + value = settings.epsilon, + onValueChange = { onSettingsChange(settings.copy(epsilon = it)) }, + valueRange = 0.16f..0.28f, + icon = Icons.Default.Category + ) + + HorizontalDivider() + + // Info Card + InfoCard( + text = "These settings control face quality, not photo type. " + + "Group photos are included - we extract the best face from each." + ) + + // Preset Buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + OutlinedButton( + onClick = { onSettingsChange(DiscoverySettings.STRICT) }, + modifier = Modifier.weight(1f) + ) { + Text("High Quality", style = MaterialTheme.typography.bodySmall) + } + + Button( + onClick = { onSettingsChange(DiscoverySettings.DEFAULT) }, + modifier = Modifier.weight(1f) + ) { + Text("Balanced", style = MaterialTheme.typography.bodySmall) + } + + OutlinedButton( + onClick = { onSettingsChange(DiscoverySettings.LOOSE) }, + modifier = Modifier.weight(1f) + ) { + Text("More Faces", style = MaterialTheme.typography.bodySmall) + } + } + } + } + } + } +} + +/** + * Individual quality slider component + */ +@Composable +private fun QualitySlider( + title: String, + description: String, + currentValue: String, + value: Float, + onValueChange: (Float) -> Unit, + valueRange: ClosedFloatingPointRange, + icon: androidx.compose.ui.graphics.vector.ImageVector +) { + Column( + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Header + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + Icon( + imageVector = icon, + contentDescription = null, + tint = MaterialTheme.colorScheme.primary, + modifier = Modifier.size(20.dp) + ) + Text( + text = title, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium + ) + } + + Surface( + shape = MaterialTheme.shapes.small, + color = MaterialTheme.colorScheme.primaryContainer + ) { + Text( + text = currentValue, + modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp), + style = MaterialTheme.typography.labelLarge, + color = MaterialTheme.colorScheme.onPrimaryContainer, + fontWeight = FontWeight.Bold + ) + } + } + + // Description + Text( + text = description, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + // Slider + Slider( + value = value, + onValueChange = onValueChange, + valueRange = valueRange + ) + } +} + +/** + * Info card component + */ +@Composable +private fun InfoCard(text: String) { + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.5f) + ) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Info, + contentDescription = null, + tint = MaterialTheme.colorScheme.onSecondaryContainer, + modifier = Modifier.size(18.dp) + ) + Text( + text = text, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + } + } +} + +/** + * Discovery settings data class + */ +data class DiscoverySettings( + val minFaceSize: Float = 0.03f, // 3% of image (balanced) + val minQuality: Float = 0.6f, // 60% quality (good) + val epsilon: Float = 0.22f // DBSCAN threshold (balanced) +) { + companion object { + // Balanced - Default recommended settings + val DEFAULT = DiscoverySettings( + minFaceSize = 0.03f, + minQuality = 0.6f, + epsilon = 0.22f + ) + + // Strict - High quality, fewer faces + val STRICT = DiscoverySettings( + minFaceSize = 0.05f, // 5% of image + minQuality = 0.7f, // 70% quality + epsilon = 0.18f // Tight clustering + ) + + // Loose - More faces, lower quality threshold + val LOOSE = DiscoverySettings( + minFaceSize = 0.02f, // 2% of image + minQuality = 0.5f, // 50% quality + epsilon = 0.26f // Loose clustering + ) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt new file mode 100644 index 0000000..c089193 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt @@ -0,0 +1,637 @@ +package com.placeholder.sherpai2.ui.discover + +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyRow +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.platform.LocalSoftwareKeyboardController +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.text.input.KeyboardCapitalization +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog +import coil.compose.AsyncImage +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier +import com.placeholder.sherpai2.domain.clustering.FaceCluster +import java.text.SimpleDateFormat +import java.util.* + +/** + * NamingDialog - ENHANCED with Retry Button + * + * NEW FEATURE: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * - Added onRetry parameter + * - Shows retry button for poor quality clusters + * - Also shows secondary retry option for good clusters + * + * All existing features preserved: + * - Name input with validation + * - Child toggle with date of birth picker + * - Sibling cluster selection + * - Quality warnings display + * - Preview of representative faces + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun NamingDialog( + cluster: FaceCluster, + suggestedSiblings: List, + onConfirm: (name: String, dateOfBirth: Long?, isChild: Boolean, selectedSiblings: List) -> Unit, + onRetry: () -> Unit = {}, // NEW: Retry with different settings + onDismiss: () -> Unit, + qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() } +) { + var name by remember { mutableStateOf("") } + var isChild by remember { mutableStateOf(false) } + var showDatePicker by remember { mutableStateOf(false) } + var dateOfBirth by remember { mutableStateOf(null) } + var selectedSiblingIds by remember { mutableStateOf(setOf()) } + + // Analyze cluster quality + val qualityResult = remember(cluster) { + qualityAnalyzer.analyzeCluster(cluster) + } + + val keyboardController = LocalSoftwareKeyboardController.current + val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy", Locale.getDefault()) } + + Dialog(onDismissRequest = onDismiss) { + Card( + modifier = Modifier + .fillMaxWidth() + .fillMaxHeight(0.9f), + shape = RoundedCornerShape(16.dp), + elevation = CardDefaults.cardElevation(defaultElevation = 8.dp) + ) { + Column( + modifier = Modifier + .fillMaxSize() + .verticalScroll(rememberScrollState()) + ) { + // Header + Surface( + color = MaterialTheme.colorScheme.primaryContainer + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = "Name This Person", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + + IconButton(onClick = onDismiss) { + Icon( + imageVector = Icons.Default.Close, + contentDescription = "Close", + tint = MaterialTheme.colorScheme.onPrimaryContainer + ) + } + } + } + + Column( + modifier = Modifier.padding(16.dp) + ) { + // ════════════════════════════════════════ + // NEW: Poor Quality Warning with Retry + // ════════════════════════════════════════ + if (qualityResult.qualityTier == ClusterQualityTier.POOR) { + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.errorContainer + ), + modifier = Modifier.fillMaxWidth() + ) { + Column( + modifier = Modifier.padding(16.dp), + verticalArrangement = Arrangement.spacedBy(12.dp) + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.Warning, + contentDescription = null, + tint = MaterialTheme.colorScheme.onErrorContainer + ) + Text( + text = "Poor Quality Cluster", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onErrorContainer + ) + } + + Text( + text = "This cluster doesn't meet quality requirements:", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onErrorContainer + ) + + Column(verticalArrangement = Arrangement.spacedBy(4.dp)) { + qualityResult.warnings.forEach { warning -> + Row(horizontalArrangement = Arrangement.spacedBy(8.dp)) { + Text("•", color = MaterialTheme.colorScheme.onErrorContainer) + Text( + warning, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onErrorContainer + ) + } + } + } + + HorizontalDivider( + color = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.3f) + ) + + Button( + onClick = onRetry, + modifier = Modifier.fillMaxWidth(), + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.error, + contentColor = MaterialTheme.colorScheme.onError + ) + ) { + Icon(Icons.Default.Refresh, contentDescription = null) + Spacer(Modifier.width(8.dp)) + Text("Retry with Different Settings") + } + } + } + + Spacer(modifier = Modifier.height(16.dp)) + } else if (qualityResult.warnings.isNotEmpty()) { + // Minor warnings for good/excellent clusters + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.5f) + ) + ) { + Column( + modifier = Modifier.padding(12.dp), + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + qualityResult.warnings.take(3).forEach { warning -> + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.Top + ) { + Icon( + Icons.Default.Info, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSecondaryContainer + ) + Text( + warning, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + } + } + } + } + + Spacer(modifier = Modifier.height(16.dp)) + } + + // Quality badge + Surface( + color = when (qualityResult.qualityTier) { + ClusterQualityTier.EXCELLENT -> Color(0xFF1B5E20) + ClusterQualityTier.GOOD -> Color(0xFF2E7D32) + ClusterQualityTier.POOR -> Color(0xFFD32F2F) + }, + shape = RoundedCornerShape(8.dp) + ) { + Row( + modifier = Modifier.padding(horizontal = 12.dp, vertical = 6.dp), + horizontalArrangement = Arrangement.spacedBy(4.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = when (qualityResult.qualityTier) { + ClusterQualityTier.EXCELLENT, ClusterQualityTier.GOOD -> Icons.Default.Check + ClusterQualityTier.POOR -> Icons.Default.Warning + }, + contentDescription = null, + tint = Color.White, + modifier = Modifier.size(16.dp) + ) + Text( + text = "${qualityResult.qualityTier.name} Quality", + style = MaterialTheme.typography.labelMedium, + color = Color.White, + fontWeight = FontWeight.SemiBold + ) + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + // Stats + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceEvenly + ) { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Text( + text = "${qualityResult.soloPhotoCount}", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = "Solo Photos", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Text( + text = "${qualityResult.cleanFaceCount}", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = "Clean Faces", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Text( + text = "${(qualityResult.qualityScore * 100).toInt()}%", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = "Quality", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + Spacer(modifier = Modifier.height(24.dp)) + + // Representative faces preview + if (cluster.representativeFaces.isNotEmpty()) { + Text( + text = "Representative Faces", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.SemiBold, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(8.dp)) + + LazyRow( + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(cluster.representativeFaces.take(6)) { face -> + AsyncImage( + model = android.net.Uri.parse(face.imageUri), + contentDescription = null, + modifier = Modifier + .size(80.dp) + .clip(RoundedCornerShape(8.dp)) + .border( + 2.dp, + MaterialTheme.colorScheme.outline.copy(alpha = 0.2f), + RoundedCornerShape(8.dp) + ), + contentScale = ContentScale.Crop + ) + } + } + + Spacer(modifier = Modifier.height(20.dp)) + } + + // Name input + OutlinedTextField( + value = name, + onValueChange = { name = it }, + label = { Text("Name") }, + placeholder = { Text("e.g., Emma") }, + leadingIcon = { + Icon( + imageVector = Icons.Default.Person, + contentDescription = null + ) + }, + keyboardOptions = KeyboardOptions( + capitalization = KeyboardCapitalization.Words, + imeAction = ImeAction.Done + ), + keyboardActions = KeyboardActions( + onDone = { keyboardController?.hide() } + ), + singleLine = true, + modifier = Modifier.fillMaxWidth(), + enabled = qualityResult.canTrain + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Child toggle + Surface( + modifier = Modifier.fillMaxWidth(), + color = if (isChild) MaterialTheme.colorScheme.primaryContainer + else MaterialTheme.colorScheme.surfaceVariant, + shape = RoundedCornerShape(12.dp) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .clickable(enabled = qualityResult.canTrain) { isChild = !isChild } + .padding(16.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Row( + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Face, + contentDescription = null, + tint = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer + else MaterialTheme.colorScheme.onSurfaceVariant + ) + Spacer(modifier = Modifier.width(12.dp)) + Column { + Text( + text = "This is a child", + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.Medium, + color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer + else MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "For age-appropriate filtering", + style = MaterialTheme.typography.bodySmall, + color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f) + else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + + Switch( + checked = isChild, + onCheckedChange = null, // Handled by row click + enabled = qualityResult.canTrain + ) + } + } + + // Date of birth (if child) + if (isChild) { + Spacer(modifier = Modifier.height(12.dp)) + + OutlinedButton( + onClick = { showDatePicker = true }, + modifier = Modifier.fillMaxWidth(), + enabled = qualityResult.canTrain + ) { + Icon( + imageVector = Icons.Default.DateRange, + contentDescription = null + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = dateOfBirth?.let { dateFormatter.format(Date(it)) } + ?: "Set date of birth (optional)" + ) + } + } + + // Sibling selection + if (suggestedSiblings.isNotEmpty()) { + Spacer(modifier = Modifier.height(20.dp)) + + Text( + text = "Appears with", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold + ) + + Text( + text = "Select siblings or family members", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(8.dp)) + + suggestedSiblings.forEach { sibling -> + SiblingSelectionItem( + cluster = sibling, + selected = sibling.clusterId in selectedSiblingIds, + onToggle = { + selectedSiblingIds = if (sibling.clusterId in selectedSiblingIds) { + selectedSiblingIds - sibling.clusterId + } else { + selectedSiblingIds + sibling.clusterId + } + }, + enabled = qualityResult.canTrain + ) + Spacer(modifier = Modifier.height(8.dp)) + } + } + + Spacer(modifier = Modifier.height(24.dp)) + + // ════════════════════════════════════════ + // Action buttons + // ════════════════════════════════════════ + if (qualityResult.qualityTier == ClusterQualityTier.POOR) { + // Poor quality - Cancel only (retry button is above) + OutlinedButton( + onClick = onDismiss, + modifier = Modifier.fillMaxWidth() + ) { + Text("Cancel") + } + } else { + // Good quality - Normal flow + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onDismiss, + modifier = Modifier.weight(1f) + ) { + Text("Cancel") + } + + Button( + onClick = { + if (name.isNotBlank()) { + onConfirm( + name.trim(), + dateOfBirth, + isChild, + selectedSiblingIds.toList() + ) + } + }, + modifier = Modifier.weight(1f), + enabled = name.isNotBlank() && qualityResult.canTrain + ) { + Icon( + imageVector = Icons.Default.Check, + contentDescription = null, + modifier = Modifier.size(20.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + Text("Create Model") + } + } + + // ════════════════════════════════════════ + // NEW: Secondary retry option + // ════════════════════════════════════════ + Spacer(modifier = Modifier.height(8.dp)) + + TextButton( + onClick = onRetry, + modifier = Modifier.fillMaxWidth() + ) { + Icon( + Icons.Default.Refresh, + contentDescription = null, + modifier = Modifier.size(16.dp) + ) + Spacer(Modifier.width(4.dp)) + Text( + "Try again with different settings", + style = MaterialTheme.typography.bodySmall + ) + } + } + } + } + } + } + + // Date picker dialog + if (showDatePicker) { + val datePickerState = rememberDatePickerState() + + DatePickerDialog( + onDismissRequest = { showDatePicker = false }, + confirmButton = { + TextButton( + onClick = { + dateOfBirth = datePickerState.selectedDateMillis + showDatePicker = false + } + ) { + Text("OK") + } + }, + dismissButton = { + TextButton(onClick = { showDatePicker = false }) { + Text("Cancel") + } + } + ) { + DatePicker(state = datePickerState) + } + } +} + +@Composable +private fun SiblingSelectionItem( + cluster: FaceCluster, + selected: Boolean, + onToggle: () -> Unit, + enabled: Boolean = true +) { + Surface( + modifier = Modifier.fillMaxWidth(), + color = if (selected) MaterialTheme.colorScheme.primaryContainer + else MaterialTheme.colorScheme.surfaceVariant, + shape = RoundedCornerShape(8.dp) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .clickable(enabled = enabled) { onToggle() } + .padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Face preview + if (cluster.representativeFaces.isNotEmpty()) { + AsyncImage( + model = android.net.Uri.parse(cluster.representativeFaces.first().imageUri), + contentDescription = null, + modifier = Modifier + .size(48.dp) + .clip(CircleShape) + .border(2.dp, MaterialTheme.colorScheme.outline.copy(alpha = 0.2f), CircleShape), + contentScale = ContentScale.Crop + ) + } + + Column { + Text( + text = "Person ${cluster.clusterId + 1}", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium + ) + Text( + text = "${cluster.photoCount} photos", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + Checkbox( + checked = selected, + onCheckedChange = null, // Handled by row click + enabled = enabled + ) + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Temporalnamingdialog.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Temporalnamingdialog.kt new file mode 100644 index 0000000..3af7b43 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Temporalnamingdialog.kt @@ -0,0 +1,353 @@ +package com.placeholder.sherpai2.ui.discover + +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog +import com.placeholder.sherpai2.domain.clustering.AnnotatedCluster +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult + +/** + * TemporalNamingDialog - ENHANCED with age input for temporal clustering + * + * NEW FEATURES: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * ✅ Name field: "Emma" + * ✅ Age field: "2" (optional but recommended) + * ✅ Year display: "Photos from 2020" + * ✅ Auto-suggest: If year=2020 and DOB=2018 → Age=2 + * + * NAMING PATTERNS: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Adults: + * - Name: "John Doe" + * - Age: (leave empty) + * - Result: Person "John Doe" with single model + * + * Children (with age): + * - Name: "Emma" + * - Age: "2" + * - Year: "2020" + * - Result: Person "Emma" with submodel "Emma_Age_2" + * + * Children (without age): + * - Name: "Emma" + * - Age: (empty) + * - Year: "2020" + * - Result: Person "Emma" with submodel "Emma_2020" + */ +@Composable +fun TemporalNamingDialog( + annotatedCluster: AnnotatedCluster, + onConfirm: (name: String, age: Int?, isChild: Boolean) -> Unit, + onDismiss: () -> Unit, + qualityAnalyzer: ClusterQualityAnalyzer +) { + var name by remember { mutableStateOf(annotatedCluster.suggestedName ?: "") } + var ageText by remember { mutableStateOf(annotatedCluster.suggestedAge?.toString() ?: "") } + var isChild by remember { mutableStateOf(annotatedCluster.suggestedAge != null) } + + // Analyze cluster quality + val qualityResult = remember(annotatedCluster.cluster) { + qualityAnalyzer.analyzeCluster(annotatedCluster.cluster) + } + + Dialog(onDismissRequest = onDismiss) { + Card( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp) + ) { + Column( + modifier = Modifier.padding(24.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Header + Text( + text = "Name This Person", + style = MaterialTheme.typography.headlineSmall, + fontWeight = FontWeight.Bold + ) + + // Year badge + YearBadge(year = annotatedCluster.year) + + HorizontalDivider() + + // Quality warnings + QualityWarnings(qualityResult) + + // Name field + OutlinedTextField( + value = name, + onValueChange = { name = it }, + label = { Text("Name") }, + placeholder = { Text("e.g., Emma") }, + leadingIcon = { + Icon(Icons.Default.Person, contentDescription = null) + }, + modifier = Modifier.fillMaxWidth(), + singleLine = true + ) + + // Child checkbox + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically + ) { + Checkbox( + checked = isChild, + onCheckedChange = { isChild = it } + ) + Spacer(modifier = Modifier.width(8.dp)) + Column { + Text( + text = "This is a child", + style = MaterialTheme.typography.bodyMedium + ) + Text( + text = "Enable age-specific models", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + // Age field (only if child) + if (isChild) { + OutlinedTextField( + value = ageText, + onValueChange = { + // Only allow numbers + if (it.isEmpty() || it.all { c -> c.isDigit() }) { + ageText = it + } + }, + label = { Text("Age in ${annotatedCluster.year}") }, + placeholder = { Text("e.g., 2") }, + leadingIcon = { + Icon(Icons.Default.DateRange, contentDescription = null) + }, + modifier = Modifier.fillMaxWidth(), + singleLine = true, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number), + supportingText = { + Text("Optional: Helps create age-specific models") + } + ) + + // Model name preview + if (name.isNotBlank()) { + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Info, + contentDescription = null, + tint = MaterialTheme.colorScheme.onPrimaryContainer + ) + Spacer(modifier = Modifier.width(8.dp)) + Column { + Text( + text = "Model will be created as:", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + Text( + text = buildModelName(name, ageText, annotatedCluster.year), + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + } + } + } + } + } + + // Cluster stats + ClusterStats(qualityResult) + + HorizontalDivider() + + // Actions + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onDismiss, + modifier = Modifier.weight(1f) + ) { + Text("Cancel") + } + + Button( + onClick = { + val age = ageText.toIntOrNull() + onConfirm(name, age, isChild) + }, + modifier = Modifier.weight(1f), + enabled = name.isNotBlank() && qualityResult.canTrain + ) { + Text("Create") + } + } + } + } + } +} + +/** + * Year badge showing photo year + */ +@Composable +private fun YearBadge(year: String) { + Surface( + color = MaterialTheme.colorScheme.secondaryContainer, + shape = MaterialTheme.shapes.small + ) { + Row( + modifier = Modifier.padding(horizontal = 12.dp, vertical = 6.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(4.dp) + ) { + Icon( + imageVector = Icons.Default.DateRange, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSecondaryContainer + ) + Text( + text = "Photos from $year", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + } + } +} + +/** + * Quality warnings + */ +@Composable +private fun QualityWarnings(qualityResult: ClusterQualityResult) { + if (qualityResult.warnings.isNotEmpty()) { + Card( + colors = CardDefaults.cardColors( + containerColor = when (qualityResult.qualityTier) { + com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR -> + MaterialTheme.colorScheme.errorContainer + com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.GOOD -> + MaterialTheme.colorScheme.tertiaryContainer + else -> MaterialTheme.colorScheme.surfaceVariant + } + ) + ) { + Column( + modifier = Modifier.padding(12.dp), + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + qualityResult.warnings.take(3).forEach { warning -> + Row( + verticalAlignment = Alignment.Top, + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + Icon( + imageVector = when (qualityResult.qualityTier) { + com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR -> + Icons.Default.Warning + else -> Icons.Default.Info + }, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = when (qualityResult.qualityTier) { + com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR -> + MaterialTheme.colorScheme.onErrorContainer + else -> MaterialTheme.colorScheme.onSurfaceVariant + } + ) + Text( + text = warning, + style = MaterialTheme.typography.bodySmall, + color = when (qualityResult.qualityTier) { + com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR -> + MaterialTheme.colorScheme.onErrorContainer + else -> MaterialTheme.colorScheme.onSurfaceVariant + } + ) + } + } + } + } + } +} + +/** + * Cluster statistics + */ +@Composable +private fun ClusterStats(qualityResult: ClusterQualityResult) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceEvenly + ) { + StatItem( + label = "Photos", + value = qualityResult.soloPhotoCount.toString() + ) + StatItem( + label = "Clean Faces", + value = qualityResult.cleanFaceCount.toString() + ) + StatItem( + label = "Quality", + value = "${(qualityResult.qualityScore * 100).toInt()}%" + ) + } +} + +@Composable +private fun StatItem(label: String, value: String) { + Column( + horizontalAlignment = Alignment.CenterHorizontally + ) { + Text( + text = value, + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = label, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} + +/** + * Build model name preview + */ +private fun buildModelName(name: String, ageText: String, year: String): String { + return when { + ageText.isNotBlank() -> "${name}_Age_${ageText}" + else -> "${name}_${year}" + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Validationpreviewscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Validationpreviewscreen.kt new file mode 100644 index 0000000..0a7b61a --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Validationpreviewscreen.kt @@ -0,0 +1,613 @@ +package com.placeholder.sherpai2.ui.discover + +import android.net.Uri +import androidx.compose.animation.AnimatedVisibility +import androidx.compose.animation.core.animateFloatAsState +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.gestures.detectDragGestures +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.grid.GridCells +import androidx.compose.foundation.lazy.grid.LazyVerticalGrid +import androidx.compose.foundation.lazy.grid.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.draw.scale +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.input.pointer.pointerInput +import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.IntOffset +import androidx.compose.ui.unit.dp +import androidx.compose.ui.zIndex +import coil.compose.AsyncImage +import com.placeholder.sherpai2.data.local.entity.FeedbackType +import com.placeholder.sherpai2.domain.validation.ValidationScanResult +import com.placeholder.sherpai2.domain.validation.ValidationMatch +import kotlin.math.roundToInt + +/** + * ValidationPreviewScreen - User reviews validation results with swipe gestures + * + * FEATURES: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * ✅ Swipe right (✓) = Confirmed match + * ✅ Swipe left (✗) = Rejected match + * ✅ Tap = Mark uncertain (?) + * ✅ Real-time feedback stats + * ✅ Automatic refinement recommendation + * ✅ Bottom bar with approve/reject/refine actions + * + * FLOW: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * 1. User swipes/taps to mark faces + * 2. Feedback tracked in local state + * 3. If >15% rejection → "Refine" button appears + * 4. Approve → Sends feedback map to ViewModel + * 5. Reject → Returns to previous screen + * 6. Refine → Triggers cluster refinement + */ +@Composable +fun ValidationPreviewScreen( + personName: String, + validationResult: ValidationScanResult, + onMarkFeedback: (Map) -> Unit = {}, + onRequestRefinement: () -> Unit = {}, + onApprove: () -> Unit, + onReject: () -> Unit, + modifier: Modifier = Modifier +) { + // Get sample images from validation result matches + val sampleMatches = remember(validationResult) { + validationResult.matches.take(24) // Show up to 24 faces + } + + // Track feedback for each image (imageId -> FeedbackType) + var feedbackMap by remember { + mutableStateOf>(emptyMap()) + } + + // Calculate feedback statistics + val confirmedCount = feedbackMap.count { it.value == FeedbackType.CONFIRMED_MATCH } + val rejectedCount = feedbackMap.count { it.value == FeedbackType.REJECTED_MATCH } + val uncertainCount = feedbackMap.count { it.value == FeedbackType.UNCERTAIN } + val reviewedCount = feedbackMap.size + val totalCount = sampleMatches.size + + // Determine if refinement is recommended + val rejectionRatio = if (reviewedCount > 0) { + rejectedCount.toFloat() / reviewedCount.toFloat() + } else { + 0f + } + val shouldRefine = rejectionRatio > 0.15f && rejectedCount >= 2 + + Scaffold( + bottomBar = { + ValidationBottomBar( + confirmedCount = confirmedCount, + rejectedCount = rejectedCount, + uncertainCount = uncertainCount, + reviewedCount = reviewedCount, + totalCount = totalCount, + shouldRefine = shouldRefine, + onApprove = { + onMarkFeedback(feedbackMap) + onApprove() + }, + onReject = onReject, + onRefine = { + onMarkFeedback(feedbackMap) + onRequestRefinement() + } + ) + } + ) { paddingValues -> + Column( + modifier = modifier + .fillMaxSize() + .padding(paddingValues) + .padding(16.dp) + ) { + // Header + Text( + text = "Validate \"$personName\"", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(8.dp)) + + // Instructions + InstructionsCard() + + Spacer(modifier = Modifier.height(16.dp)) + + // Feedback stats + FeedbackStatsCard( + confirmedCount = confirmedCount, + rejectedCount = rejectedCount, + uncertainCount = uncertainCount, + reviewedCount = reviewedCount, + totalCount = totalCount + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Grid of faces to review + LazyVerticalGrid( + columns = GridCells.Fixed(3), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalArrangement = Arrangement.spacedBy(8.dp), + modifier = Modifier.weight(1f) + ) { + items( + items = sampleMatches, + key = { match -> match.imageId } + ) { match -> + SwipeableFaceCard( + match = match, + currentFeedback = feedbackMap[match.imageId], + onFeedbackChange = { feedback -> + feedbackMap = feedbackMap.toMutableMap().apply { + put(match.imageId, feedback) + } + } + ) + } + } + } + } +} + +/** + * Swipeable face card with visual feedback indicators + */ +@Composable +private fun SwipeableFaceCard( + match: ValidationMatch, + currentFeedback: FeedbackType?, + onFeedbackChange: (FeedbackType) -> Unit +) { + var offsetX by remember { mutableFloatStateOf(0f) } + var isDragging by remember { mutableStateOf(false) } + + val scale by animateFloatAsState( + targetValue = if (isDragging) 1.1f else 1f, + label = "scale" + ) + + Box( + modifier = Modifier + .aspectRatio(1f) + .scale(scale) + .zIndex(if (isDragging) 1f else 0f) + ) { + // Face image with border color based on feedback + AsyncImage( + model = Uri.parse(match.imageUri), + contentDescription = "Face", + modifier = Modifier + .fillMaxSize() + .clip(RoundedCornerShape(12.dp)) + .border( + width = 3.dp, + color = when (currentFeedback) { + FeedbackType.CONFIRMED_MATCH -> Color(0xFF4CAF50) // Green + FeedbackType.REJECTED_MATCH -> Color(0xFFF44336) // Red + FeedbackType.UNCERTAIN -> Color(0xFFFF9800) // Orange + else -> MaterialTheme.colorScheme.outline + }, + shape = RoundedCornerShape(12.dp) + ) + .offset { IntOffset(offsetX.roundToInt(), 0) } + .pointerInput(Unit) { + detectDragGestures( + onDragStart = { + isDragging = true + }, + onDrag = { _, dragAmount -> + offsetX += dragAmount.x + }, + onDragEnd = { + isDragging = false + + // Determine feedback based on swipe direction + when { + offsetX > 100 -> { + onFeedbackChange(FeedbackType.CONFIRMED_MATCH) + } + offsetX < -100 -> { + onFeedbackChange(FeedbackType.REJECTED_MATCH) + } + } + + // Reset position + offsetX = 0f + }, + onDragCancel = { + isDragging = false + offsetX = 0f + } + ) + } + .clickable { + // Tap to toggle uncertain + val newFeedback = when (currentFeedback) { + FeedbackType.UNCERTAIN -> null + else -> FeedbackType.UNCERTAIN + } + if (newFeedback != null) { + onFeedbackChange(newFeedback) + } + }, + contentScale = ContentScale.Crop + ) + + // Confidence badge (top-left) + Surface( + modifier = Modifier + .align(Alignment.TopStart) + .padding(4.dp), + shape = RoundedCornerShape(4.dp), + color = Color.Black.copy(alpha = 0.6f) + ) { + Text( + text = "${(match.confidence * 100).toInt()}%", + modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp), + style = MaterialTheme.typography.labelSmall, + color = Color.White, + fontWeight = FontWeight.Bold + ) + } + + // Feedback indicator overlay (top-right) + if (currentFeedback != null) { + Surface( + modifier = Modifier + .align(Alignment.TopEnd) + .padding(4.dp), + shape = CircleShape, + color = when (currentFeedback) { + FeedbackType.CONFIRMED_MATCH -> Color(0xFF4CAF50) + FeedbackType.REJECTED_MATCH -> Color(0xFFF44336) + FeedbackType.UNCERTAIN -> Color(0xFFFF9800) + else -> Color.Transparent + }, + shadowElevation = 2.dp + ) { + Icon( + imageVector = when (currentFeedback) { + FeedbackType.CONFIRMED_MATCH -> Icons.Default.Check + FeedbackType.REJECTED_MATCH -> Icons.Default.Close + FeedbackType.UNCERTAIN -> Icons.Default.Warning + else -> Icons.Default.Info + }, + contentDescription = currentFeedback.name, + tint = Color.White, + modifier = Modifier + .size(32.dp) + .padding(6.dp) + ) + } + } + + // Swipe hint during drag + if (isDragging) { + SwipeDragHint(offsetX = offsetX) + } + } +} + +/** + * Swipe drag hint overlay + */ +@Composable +private fun BoxScope.SwipeDragHint(offsetX: Float) { + val hintText = when { + offsetX > 50 -> "✓ Correct" + offsetX < -50 -> "✗ Incorrect" + else -> "Keep swiping" + } + + val hintColor = when { + offsetX > 50 -> Color(0xFF4CAF50) + offsetX < -50 -> Color(0xFFF44336) + else -> Color.Gray + } + + Surface( + modifier = Modifier + .align(Alignment.BottomCenter) + .padding(8.dp), + shape = RoundedCornerShape(4.dp), + color = hintColor.copy(alpha = 0.9f) + ) { + Text( + text = hintText, + modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp), + style = MaterialTheme.typography.labelSmall, + color = Color.White, + fontWeight = FontWeight.Bold + ) + } +} + +/** + * Instructions card showing gesture controls + */ +@Composable +private fun InstructionsCard() { + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ) + ) { + Row( + modifier = Modifier.padding(16.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Info, + contentDescription = null, + tint = MaterialTheme.colorScheme.onPrimaryContainer + ) + + Spacer(modifier = Modifier.width(12.dp)) + + Column { + Text( + text = "Review Detected Faces", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = "Swipe right ✅ for correct, left ❌ for incorrect, tap ❓ for uncertain", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + } + } + } +} + +/** + * Feedback statistics card + */ +@Composable +private fun FeedbackStatsCard( + confirmedCount: Int, + rejectedCount: Int, + uncertainCount: Int, + reviewedCount: Int, + totalCount: Int +) { + Card { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.SpaceEvenly + ) { + FeedbackStat( + icon = Icons.Default.Check, + color = Color(0xFF4CAF50), + count = confirmedCount, + label = "Correct" + ) + + FeedbackStat( + icon = Icons.Default.Close, + color = Color(0xFFF44336), + count = rejectedCount, + label = "Incorrect" + ) + + FeedbackStat( + icon = Icons.Default.Warning, + color = Color(0xFFFF9800), + count = uncertainCount, + label = "Uncertain" + ) + } + + val progressValue = if (totalCount > 0) { + reviewedCount.toFloat() / totalCount.toFloat() + } else { + 0f + } + + LinearProgressIndicator( + progress = { progressValue }, + modifier = Modifier + .fillMaxWidth() + .height(4.dp) + ) + } +} + +/** + * Individual feedback statistic item + */ +@Composable +private fun FeedbackStat( + icon: androidx.compose.ui.graphics.vector.ImageVector, + color: Color, + count: Int, + label: String +) { + Column( + horizontalAlignment = Alignment.CenterHorizontally + ) { + Surface( + shape = CircleShape, + color = color.copy(alpha = 0.2f) + ) { + Icon( + imageVector = icon, + contentDescription = null, + tint = color, + modifier = Modifier + .size(40.dp) + .padding(8.dp) + ) + } + + Spacer(modifier = Modifier.height(4.dp)) + + Text( + text = count.toString(), + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold + ) + + Text( + text = label, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} + +/** + * Bottom action bar with approve/reject/refine buttons + */ +@Composable +private fun ValidationBottomBar( + confirmedCount: Int, + rejectedCount: Int, + uncertainCount: Int, + reviewedCount: Int, + totalCount: Int, + shouldRefine: Boolean, + onApprove: () -> Unit, + onReject: () -> Unit, + onRefine: () -> Unit +) { + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surface, + shadowElevation = 8.dp + ) { + Column( + modifier = Modifier.padding(16.dp) + ) { + // Refinement warning banner + AnimatedVisibility(visible = shouldRefine) { + RefinementWarningBanner( + rejectedCount = rejectedCount, + reviewedCount = reviewedCount, + onRefine = onRefine + ) + } + + // Main action buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onReject, + modifier = Modifier.weight(1f) + ) { + Icon(Icons.Default.Close, contentDescription = null) + Spacer(modifier = Modifier.width(8.dp)) + Text("Reject") + } + + Button( + onClick = onApprove, + modifier = Modifier.weight(1f), + enabled = confirmedCount > 0 || (reviewedCount == 0 && totalCount > 6) + ) { + Icon(Icons.Default.Check, contentDescription = null) + Spacer(modifier = Modifier.width(8.dp)) + Text("Approve") + } + } + + // Review progress text + Spacer(modifier = Modifier.height(8.dp)) + + Text( + text = if (reviewedCount == 0) { + "Review faces above or approve to continue" + } else { + "Reviewed $reviewedCount of $totalCount faces" + }, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + textAlign = TextAlign.Center, + modifier = Modifier.fillMaxWidth() + ) + } + } +} + +/** + * Refinement warning banner component + */ +@Composable +private fun RefinementWarningBanner( + rejectedCount: Int, + reviewedCount: Int, + onRefine: () -> Unit +) { + Column { + Card( + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.errorContainer + ), + modifier = Modifier.fillMaxWidth() + ) { + Row( + modifier = Modifier.padding(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Warning, + contentDescription = null, + tint = MaterialTheme.colorScheme.onErrorContainer + ) + + Spacer(modifier = Modifier.width(12.dp)) + + Column(modifier = Modifier.weight(1f)) { + Text( + text = "High Rejection Rate", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onErrorContainer + ) + Text( + text = "${(rejectedCount.toFloat() / reviewedCount.toFloat() * 100).toInt()}% rejected. Consider refining the cluster.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onErrorContainer + ) + } + + Button( + onClick = onRefine, + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.error + ) + ) { + Text("Refine") + } + } + } + + Spacer(modifier = Modifier.height(12.dp)) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt index 9e3ad00..ace3d95 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt @@ -47,31 +47,31 @@ sealed class AppDestinations( description = "Your photo collections" ) - // ImageDetail is not in draw er (internal navigation only) + // ImageDetail is not in drawer (internal navigation only) // ================== // FACE RECOGNITION // ================== + data object Discover : AppDestinations( + route = AppRoutes.DISCOVER, + icon = Icons.Default.AutoAwesome, + label = "Discover", + description = "Find people in your photos" + ) + data object Inventory : AppDestinations( route = AppRoutes.INVENTORY, icon = Icons.Default.Face, - label = "People Models", - description = "Existing Face Detection Models" + label = "People", + description = "Manage recognized people" ) data object Train : AppDestinations( route = AppRoutes.TRAIN, icon = Icons.Default.ModelTraining, - label = "Create Model", - description = "Create a new Person Model" - ) - - data object Models : AppDestinations( - route = AppRoutes.MODELS, - icon = Icons.Default.SmartToy, - label = "Generative", - description = "AI Creation" + label = "Train Model", + description = "Create a new person model" ) // ================== @@ -117,9 +117,9 @@ val photoDestinations = listOf( // Face recognition section val faceRecognitionDestinations = listOf( + AppDestinations.Discover, // ✨ NEW: Auto-cluster discovery AppDestinations.Inventory, - AppDestinations.Train, - AppDestinations.Models + AppDestinations.Train ) // Organization section @@ -145,22 +145,12 @@ fun getDestinationByRoute(route: String?): AppDestinations? { AppRoutes.SEARCH -> AppDestinations.Search AppRoutes.EXPLORE -> AppDestinations.Explore AppRoutes.COLLECTIONS -> AppDestinations.Collections + AppRoutes.DISCOVER -> AppDestinations.Discover AppRoutes.INVENTORY -> AppDestinations.Inventory AppRoutes.TRAIN -> AppDestinations.Train - AppRoutes.MODELS -> AppDestinations.Models AppRoutes.TAGS -> AppDestinations.Tags AppRoutes.UTILITIES -> AppDestinations.UTILITIES AppRoutes.SETTINGS -> AppDestinations.Settings else -> null } -} - -/** - * Legacy support (for backwards compatibility) - * These match your old structure - */ -@Deprecated("Use organized groups instead", ReplaceWith("allMainDrawerDestinations")) -val mainDrawerItems = allMainDrawerDestinations - -@Deprecated("Use settingsDestination instead", ReplaceWith("listOf(settingsDestination)")) -val utilityDrawerItems = listOf(settingsDestination) \ No newline at end of file +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt index d1c9e5c..17a58d8 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt @@ -18,6 +18,7 @@ import com.placeholder.sherpai2.ui.album.AlbumViewScreen import com.placeholder.sherpai2.ui.album.AlbumViewModel import com.placeholder.sherpai2.ui.collections.CollectionsScreen import com.placeholder.sherpai2.ui.collections.CollectionsViewModel +import com.placeholder.sherpai2.ui.discover.DiscoverPeopleScreen import com.placeholder.sherpai2.ui.explore.ExploreScreen import com.placeholder.sherpai2.ui.imagedetail.ImageDetailScreen import com.placeholder.sherpai2.ui.modelinventory.PersonInventoryScreen @@ -32,15 +33,12 @@ import com.placeholder.sherpai2.ui.trainingprep.TrainingPhotoSelectorScreen import com.placeholder.sherpai2.ui.utilities.PhotoUtilitiesScreen import java.net.URLDecoder import java.net.URLEncoder +import com.placeholder.sherpai2.ui.navigation.AppRoutes /** - * AppNavHost - UPDATED with TrainingPhotoSelector integration + * AppNavHost - UPDATED with Discover People screen * - * Changes: - * - Replaced ImageSelectorScreen with TrainingPhotoSelectorScreen - * - Shows ONLY photos with faces (hasFaces=true) - * - Multi-select photo gallery for training - * - Filters 10,000 photos → ~500 with faces for fast selection + * NEW: Replaces placeholder "Models" screen with auto-clustering face discovery */ @Composable fun AppNavHost( @@ -185,6 +183,22 @@ fun AppNavHost( // FACE RECOGNITION SYSTEM // ========================================== + /** + * DISCOVER PEOPLE SCREEN - ✨ NEW! + * + * Auto-clustering face discovery with spoon-feed naming flow: + * 1. Auto-clusters all faces in library (2-5 min) + * 2. Shows beautiful grid of discovered people + * 3. User taps to name each person + * 4. Captures: name, DOB, sibling relationships + * 5. Triggers deep background scan with age tagging + * + * Replaces: Old "Models" placeholder screen + */ + composable(AppRoutes.DISCOVER) { + DiscoverPeopleScreen() + } + /** * PERSON INVENTORY SCREEN */ @@ -197,7 +211,7 @@ fun AppNavHost( } /** - * TRAINING FLOW - UPDATED with TrainingPhotoSelector + * TRAINING FLOW - Manual training (still available) */ composable(AppRoutes.TRAIN) { entry -> val trainViewModel: TrainViewModel = hiltViewModel() @@ -235,15 +249,7 @@ fun AppNavHost( } /** - * TRAINING PHOTO SELECTOR - NEW: Custom gallery with face filtering - * - * Replaces native photo picker with custom selector that: - * - Shows ONLY photos with hasFaces=true - * - Multi-select with visual feedback - * - Face count badges on each photo - * - Enforces minimum 15 photos - * - * Result: User browses ~500 photos instead of 10,000! + * TRAINING PHOTO SELECTOR - Custom gallery with face filtering */ composable(AppRoutes.TRAINING_PHOTO_SELECTOR) { TrainingPhotoSelectorScreen( @@ -261,12 +267,12 @@ fun AppNavHost( } /** - * MODELS SCREEN + * MODELS SCREEN - DEPRECATED, kept for backwards compat */ composable(AppRoutes.MODELS) { DummyScreen( title = "AI Models", - subtitle = "Manage face recognition models" + subtitle = "Use 'Discover' instead" ) } diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppRoutes.kt b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppRoutes.kt index 2a40395..c5cb1a1 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppRoutes.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppRoutes.kt @@ -17,9 +17,10 @@ object AppRoutes { const val IMAGE_DETAIL = "IMAGE_DETAIL" // Face recognition + const val DISCOVER = "discover" // ✨ NEW: Auto-cluster face discovery const val INVENTORY = "inv" const val TRAIN = "train" - const val MODELS = "models" + const val MODELS = "models" // DEPRECATED - kept for reference only // Organization const val TAGS = "tags" @@ -30,7 +31,7 @@ object AppRoutes { // Internal training flow screens const val IMAGE_SELECTOR = "Image Selection" // DEPRECATED - kept for reference only - const val TRAINING_PHOTO_SELECTOR = "training_photo_selector" // NEW: Face-filtered gallery + const val TRAINING_PHOTO_SELECTOR = "training_photo_selector" // Face-filtered gallery const val CROP_SCREEN = "CROP_SCREEN" const val TRAINING_SCREEN = "TRAINING_SCREEN" const val ScanResultsScreen = "First Scan Results" diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt index b1afd38..167ca48 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt @@ -21,7 +21,7 @@ import com.placeholder.sherpai2.ui.navigation.AppRoutes /** * SLIMMED DOWN AppDrawer - 280dp width, inline logo, cleaner sections - * NOW WITH: Scrollable support for small phones + Collections item + * UPDATED: Discover People feature with sparkle icon ✨ */ @OptIn(ExperimentalMaterial3Api::class) @Composable @@ -109,7 +109,7 @@ fun AppDrawerContent( val photoItems = listOf( DrawerItem(AppRoutes.SEARCH, "Search", Icons.Default.Search), DrawerItem(AppRoutes.EXPLORE, "Explore", Icons.Default.Explore), - DrawerItem(AppRoutes.COLLECTIONS, "Collections", Icons.Default.Collections) // NEW! + DrawerItem(AppRoutes.COLLECTIONS, "Collections", Icons.Default.Collections) ) photoItems.forEach { item -> @@ -126,9 +126,9 @@ fun AppDrawerContent( DrawerSection(title = "Face Recognition") val faceItems = listOf( + DrawerItem(AppRoutes.DISCOVER, "Discover", Icons.Default.AutoAwesome), // ✨ UPDATED! DrawerItem(AppRoutes.INVENTORY, "People", Icons.Default.Face), - DrawerItem(AppRoutes.TRAIN, "Create Person", Icons.Default.ModelTraining), - DrawerItem(AppRoutes.MODELS, "Models", Icons.Default.SmartToy) + DrawerItem(AppRoutes.TRAIN, "Train Model", Icons.Default.ModelTraining) ) faceItems.forEach { item -> diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/Facecachepromptdialog.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/Facecachepromptdialog.kt new file mode 100644 index 0000000..2aaa0a3 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/Facecachepromptdialog.kt @@ -0,0 +1,58 @@ +package com.placeholder.sherpai2.ui.presentation + +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Face +import androidx.compose.material3.* +import androidx.compose.runtime.Composable +import androidx.compose.ui.text.font.FontWeight + +/** + * FaceCachePromptDialog - Shows on app launch if face cache needs population + * + * Location: /ui/presentation/FaceCachePromptDialog.kt (same package as MainScreen) + * + * Used by: MainScreen to prompt user to populate face cache + */ +@Composable +fun FaceCachePromptDialog( + unscannedPhotoCount: Int, + onDismiss: () -> Unit, + onScanNow: () -> Unit +) { + AlertDialog( + onDismissRequest = onDismiss, + icon = { + Icon( + imageVector = Icons.Default.Face, + contentDescription = null, + tint = MaterialTheme.colorScheme.primary + ) + }, + title = { + Text( + text = "Face Cache Needs Update", + fontWeight = FontWeight.Bold + ) + }, + text = { + Text( + text = "You have $unscannedPhotoCount photos that haven't been scanned for faces yet.\n\n" + + "Scanning is required for:\n" + + "• People Discovery\n" + + "• Face Recognition\n" + + "• Face Tagging\n\n" + + "This is a one-time scan and will run in the background." + ) + }, + confirmButton = { + Button(onClick = onScanNow) { + Text("Scan Now") + } + }, + dismissButton = { + TextButton(onClick = onDismiss) { + Text("Later") + } + } + ) +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt index 7eab4a1..fc6fa36 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt @@ -1,31 +1,48 @@ package com.placeholder.sherpai2.ui.presentation -import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.padding import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.filled.* +import androidx.compose.material.icons.filled.Menu import androidx.compose.material3.* import androidx.compose.runtime.* import androidx.compose.ui.Modifier -import androidx.compose.ui.text.font.FontWeight -import androidx.navigation.compose.currentBackStackEntryAsState +import androidx.hilt.navigation.compose.hiltViewModel import androidx.navigation.compose.rememberNavController +import androidx.navigation.compose.currentBackStackEntryAsState import com.placeholder.sherpai2.ui.navigation.AppNavHost import com.placeholder.sherpai2.ui.navigation.AppRoutes import kotlinx.coroutines.launch /** - * Clean main screen - NO duplicate FABs, Collections support + * MainScreen - Complete app container with drawer navigation + * + * CRITICAL FIX APPLIED: + * ✅ Removed AppRoutes.DISCOVER from screensWithOwnTopBar + * ✅ DiscoverPeopleScreen now shows hamburger menu + "Discover People" title! */ @OptIn(ExperimentalMaterial3Api::class) @Composable -fun MainScreen() { - val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) - val scope = rememberCoroutineScope() +fun MainScreen( + viewModel: MainViewModel = hiltViewModel() +) { val navController = rememberNavController() + val drawerState = rememberDrawerState(DrawerValue.Closed) + val scope = rememberCoroutineScope() - val navBackStackEntry by navController.currentBackStackEntryAsState() - val currentRoute = navBackStackEntry?.destination?.route ?: AppRoutes.SEARCH + val currentBackStackEntry by navController.currentBackStackEntryAsState() + val currentRoute = currentBackStackEntry?.destination?.route + + // Face cache prompt dialog state + val needsFaceCachePopulation by viewModel.needsFaceCachePopulation.collectAsState() + val unscannedPhotoCount by viewModel.unscannedPhotoCount.collectAsState() + + // ✅ CRITICAL FIX: DISCOVER is NOT in this list! + // These screens handle their own TopAppBar/navigation + val screensWithOwnTopBar = setOf( + AppRoutes.IMAGE_DETAIL, + AppRoutes.TRAINING_SCREEN, + AppRoutes.CROP_SCREEN + ) ModalNavigationDrawer( drawerState = drawerState, @@ -35,120 +52,86 @@ fun MainScreen() { onDestinationClicked = { route -> scope.launch { drawerState.close() - if (route != currentRoute) { - navController.navigate(route) { - launchSingleTop = true - } + } + navController.navigate(route) { + popUpTo(navController.graph.startDestinationId) { + saveState = true } + launchSingleTop = true + restoreState = true } } ) - }, + } ) { Scaffold( topBar = { - TopAppBar( - title = { - Column { + // ✅ Show TopAppBar for ALL screens except those with their own + if (currentRoute !in screensWithOwnTopBar) { + TopAppBar( + title = { Text( - text = getScreenTitle(currentRoute), - style = MaterialTheme.typography.titleLarge, - fontWeight = FontWeight.Bold + text = when (currentRoute) { + AppRoutes.SEARCH -> "Search" + AppRoutes.EXPLORE -> "Explore" + AppRoutes.COLLECTIONS -> "Collections" + AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW! + AppRoutes.INVENTORY -> "People" + AppRoutes.TRAIN -> "Train Model" + AppRoutes.TAGS -> "Tags" + AppRoutes.UTILITIES -> "Utilities" + AppRoutes.SETTINGS -> "Settings" + AppRoutes.MODELS -> "AI Models" + else -> { + // Handle dynamic routes like album/{type}/{id} + if (currentRoute?.startsWith("album/") == true) { + "Album" + } else { + "SherpAI" + } + } + } ) - getScreenSubtitle(currentRoute)?.let { subtitle -> - Text( - text = subtitle, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant + }, + navigationIcon = { + IconButton(onClick = { + scope.launch { + drawerState.open() + } + }) { + Icon( + imageVector = Icons.Default.Menu, + contentDescription = "Open menu" ) } - } - }, - navigationIcon = { - IconButton( - onClick = { scope.launch { drawerState.open() } } - ) { - Icon( - Icons.Default.Menu, - contentDescription = "Open Menu", - tint = MaterialTheme.colorScheme.primary - ) - } - }, - actions = { - // Dynamic actions based on current screen - when (currentRoute) { - AppRoutes.SEARCH -> { - IconButton(onClick = { /* TODO: Open filter dialog */ }) { - Icon( - Icons.Default.FilterList, - contentDescription = "Filter", - tint = MaterialTheme.colorScheme.primary - ) - } - } - AppRoutes.INVENTORY -> { - IconButton(onClick = { - navController.navigate(AppRoutes.TRAIN) - }) { - Icon( - Icons.Default.PersonAdd, - contentDescription = "Add Person", - tint = MaterialTheme.colorScheme.primary - ) - } - } - // NOTE: Removed TAGS action - TagManagementScreen has its own inline FAB - } - }, - colors = TopAppBarDefaults.topAppBarColors( - containerColor = MaterialTheme.colorScheme.surface, - titleContentColor = MaterialTheme.colorScheme.onSurface, - navigationIconContentColor = MaterialTheme.colorScheme.primary, - actionIconContentColor = MaterialTheme.colorScheme.primary + }, + colors = TopAppBarDefaults.topAppBarColors( + containerColor = MaterialTheme.colorScheme.primaryContainer, + titleContentColor = MaterialTheme.colorScheme.onPrimaryContainer, + navigationIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer, + actionIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer + ) ) - ) + } } - // NOTE: NO floatingActionButton here - individual screens manage their own FABs inline ) { paddingValues -> + // ✅ Use YOUR existing AppNavHost - it already has all the screens defined! AppNavHost( navController = navController, modifier = Modifier.padding(paddingValues) ) } } -} -/** - * Get human-readable screen title - */ -private fun getScreenTitle(route: String): String { - return when (route) { - AppRoutes.SEARCH -> "Search" - AppRoutes.EXPLORE -> "Explore" - AppRoutes.COLLECTIONS -> "Collections" // NEW! - AppRoutes.INVENTORY -> "People" - AppRoutes.TRAIN -> "Train New Person" - AppRoutes.MODELS -> "AI Models" - AppRoutes.TAGS -> "Tag Management" - AppRoutes.UTILITIES -> "Photo Util." - AppRoutes.SETTINGS -> "Settings" - else -> "SherpAI" - } -} - -/** - * Get subtitle for screens that need context - */ -private fun getScreenSubtitle(route: String): String? { - return when (route) { - AppRoutes.SEARCH -> "Find photos by tags, people, or date" - AppRoutes.EXPLORE -> "Browse your collection" - AppRoutes.COLLECTIONS -> "Your photo collections" // NEW! - AppRoutes.INVENTORY -> "Trained face models" - AppRoutes.TRAIN -> "Add a new person to recognize" - AppRoutes.TAGS -> "Organize your photo collection" - AppRoutes.UTILITIES -> "Tools for managing collection" - else -> null + // ✅ Face cache prompt dialog (shows on app launch if needed) + if (needsFaceCachePopulation) { + FaceCachePromptDialog( + unscannedPhotoCount = unscannedPhotoCount, + onDismiss = { viewModel.dismissFaceCachePrompt() }, + onScanNow = { + viewModel.dismissFaceCachePrompt() + navController.navigate(AppRoutes.UTILITIES) + } + ) } } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/Mainviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/Mainviewmodel.kt new file mode 100644 index 0000000..d2f0dbf --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/Mainviewmodel.kt @@ -0,0 +1,70 @@ +package com.placeholder.sherpai2.ui.presentation + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.placeholder.sherpai2.data.local.dao.ImageDao +import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.launch +import javax.inject.Inject + +/** + * MainViewModel - App-level state management for MainScreen + * + * Location: /ui/presentation/MainViewModel.kt (same package as MainScreen) + * + * Features: + * 1. Auto-check face cache on app launch + * 2. Prompt user if cache needs population + * 3. Track new photos that need scanning + */ +@HiltViewModel +class MainViewModel @Inject constructor( + private val imageDao: ImageDao +) : ViewModel() { + + private val _needsFaceCachePopulation = MutableStateFlow(false) + val needsFaceCachePopulation: StateFlow = _needsFaceCachePopulation.asStateFlow() + + private val _unscannedPhotoCount = MutableStateFlow(0) + val unscannedPhotoCount: StateFlow = _unscannedPhotoCount.asStateFlow() + + init { + checkFaceCache() + } + + /** + * Check if face cache needs population + */ + fun checkFaceCache() { + viewModelScope.launch(Dispatchers.IO) { + try { + // Count photos that need face detection + val unscanned = imageDao.getImagesNeedingFaceDetection().size + + _unscannedPhotoCount.value = unscanned + _needsFaceCachePopulation.value = unscanned > 0 + + } catch (e: Exception) { + // Silently fail - not critical + } + } + } + + /** + * Dismiss the face cache prompt + */ + fun dismissFaceCachePrompt() { + _needsFaceCachePopulation.value = false + } + + /** + * Refresh cache status (call after populating cache) + */ + fun refreshCacheStatus() { + checkFaceCache() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ImageSelectorViewModel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ImageSelectorViewModel.kt index 1a7c1e8..31dd554 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ImageSelectorViewModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ImageSelectorViewModel.kt @@ -14,7 +14,9 @@ import javax.inject.Inject * ImageSelectorViewModel * * Provides face-tagged image URIs for smart filtering - * during training photo selection + * during training photo selection. + * + * PRIORITIZATION: Solo photos first (faceCount=1) for clearer training data */ @HiltViewModel class ImageSelectorViewModel @Inject constructor( @@ -31,8 +33,15 @@ class ImageSelectorViewModel @Inject constructor( private fun loadFaceTaggedImages() { viewModelScope.launch { try { + // Get all images with faces val imagesWithFaces = imageDao.getImagesWithFaces() - _faceTaggedImageUris.value = imagesWithFaces.map { it.imageUri } + + // CRITICAL FIX: Sort by faceCount ASCENDING (solo photos first!) + // Previously: Sorted by faceCount DESC (group photos first - WRONG!) + // Now: Solo photos appear first, making training selection easier + val sortedImages = imagesWithFaces.sortedBy { it.faceCount } + + _faceTaggedImageUris.value = sortedImages.map { it.imageUri } } catch (e: Exception) { // If cache not available, just use empty list (filter disabled) _faceTaggedImageUris.value = emptyList() diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt index dca118b..8fccfc7 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt @@ -46,6 +46,8 @@ class TrainingPhotoSelectorViewModel @Inject constructor( * * Uses indexed query: SELECT * FROM images WHERE hasFaces = 1 * Fast! (~10ms for 10k photos) + * + * SORTED: Solo photos (faceCount=1) first for best training quality */ private fun loadPhotosWithFaces() { viewModelScope.launch { @@ -55,8 +57,9 @@ class TrainingPhotoSelectorViewModel @Inject constructor( // ✅ CRITICAL: Only get images with faces! val photos = imageDao.getImagesWithFaces() - // Sort by most faces first (better for training) - val sorted = photos.sortedByDescending { it.faceCount ?: 0 } + // ✅ FIX: Sort by LEAST faces first (solo photos = best training data) + // faceCount=1 first, then faceCount=2, etc. + val sorted = photos.sortedBy { it.faceCount ?: 999 } _photosWithFaces.value = sorted diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesscreen.kt index 92a7631..dbf1ea7 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesscreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesscreen.kt @@ -71,6 +71,8 @@ fun PhotoUtilitiesScreen( ToolsTabContent( uiState = uiState, scanProgress = scanProgress, + onPopulateFaceCache = { viewModel.populateFaceCache() }, + onForceRebuildCache = { viewModel.forceRebuildFaceCache() }, onScanPhotos = { viewModel.scanForPhotos() }, onDetectDuplicates = { viewModel.detectDuplicates() }, onDetectBursts = { viewModel.detectBursts() }, @@ -85,6 +87,8 @@ fun PhotoUtilitiesScreen( private fun ToolsTabContent( uiState: UtilitiesUiState, scanProgress: ScanProgress?, + onPopulateFaceCache: () -> Unit, + onForceRebuildCache: () -> Unit, onScanPhotos: () -> Unit, onDetectDuplicates: () -> Unit, onDetectBursts: () -> Unit, @@ -96,8 +100,39 @@ private fun ToolsTabContent( contentPadding = PaddingValues(16.dp), verticalArrangement = Arrangement.spacedBy(16.dp) ) { + // Section: Face Recognition Cache (MOST IMPORTANT) + item { + SectionHeader( + title = "Face Recognition", + icon = Icons.Default.Face + ) + } + + item { + UtilityCard( + title = "Populate Face Cache", + description = "Scan all photos to detect which ones have faces. Required for Discovery to work!", + icon = Icons.Default.FaceRetouchingNatural, + buttonText = "Scan for Faces", + enabled = uiState !is UtilitiesUiState.Scanning, + onClick = { onPopulateFaceCache() } + ) + } + + item { + UtilityCard( + title = "Force Rebuild Cache", + description = "Clear and rebuild entire face cache. Use if cache seems corrupted.", + icon = Icons.Default.Refresh, + buttonText = "Force Rebuild", + enabled = uiState !is UtilitiesUiState.Scanning, + onClick = { onForceRebuildCache() } + ) + } + // Section: Scan & Import item { + Spacer(Modifier.height(8.dp)) SectionHeader( title = "Scan & Import", icon = Icons.Default.Scanner diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesviewmodel.kt index 50f725a..8418d7e 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesviewmodel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/utilities/Photoutilitiesviewmodel.kt @@ -40,7 +40,8 @@ class PhotoUtilitiesViewModel @Inject constructor( private val imageRepository: ImageRepository, private val imageDao: ImageDao, private val tagDao: TagDao, - private val imageTagDao: ImageTagDao + private val imageTagDao: ImageTagDao, + private val populateFaceDetectionCacheUseCase: com.placeholder.sherpai2.domain.usecase.PopulateFaceDetectionCacheUseCase ) : ViewModel() { private val _uiState = MutableStateFlow(UtilitiesUiState.Idle) @@ -49,6 +50,112 @@ class PhotoUtilitiesViewModel @Inject constructor( private val _scanProgress = MutableStateFlow(null) val scanProgress: StateFlow = _scanProgress.asStateFlow() + /** + * Populate face detection cache + * Scans all photos to mark which ones have faces + */ + fun populateFaceCache() { + viewModelScope.launch(Dispatchers.IO) { + try { + _uiState.value = UtilitiesUiState.Scanning("faces") + _scanProgress.value = ScanProgress("Checking database...", 0, 0) + + // DIAGNOSTIC: Check database state + val totalImages = imageDao.getImageCount() + val needsCaching = imageDao.getImagesNeedingFaceDetectionCount() + + android.util.Log.d("FaceCache", "=== DIAGNOSTIC ===") + android.util.Log.d("FaceCache", "Total images in DB: $totalImages") + android.util.Log.d("FaceCache", "Images needing caching: $needsCaching") + + if (needsCaching == 0) { + // All images already cached! + withContext(Dispatchers.Main) { + _uiState.value = UtilitiesUiState.ScanComplete( + "All $totalImages photos already scanned!\n\nTo force re-scan, use 'Force Rebuild Cache' button.", + totalImages + ) + _scanProgress.value = null + } + return@launch + } + + _scanProgress.value = ScanProgress("Detecting faces...", 0, needsCaching) + + val scannedCount = populateFaceDetectionCacheUseCase.execute { current, total, _ -> + _scanProgress.value = ScanProgress( + "Scanning faces... $current/$total", + current, + total + ) + } + + withContext(Dispatchers.Main) { + _uiState.value = UtilitiesUiState.ScanComplete( + "Scanned $scannedCount photos for faces", + scannedCount + ) + _scanProgress.value = null + } + + } catch (e: Exception) { + android.util.Log.e("FaceCache", "Error populating cache", e) + withContext(Dispatchers.Main) { + _uiState.value = UtilitiesUiState.Error( + e.message ?: "Failed to populate face cache" + ) + _scanProgress.value = null + } + } + } + } + + /** + * Force rebuild entire face cache (re-scan ALL photos) + */ + fun forceRebuildFaceCache() { + viewModelScope.launch(Dispatchers.IO) { + try { + _uiState.value = UtilitiesUiState.Scanning("faces") + _scanProgress.value = ScanProgress("Clearing cache...", 0, 0) + + // Clear all face cache data + imageDao.clearAllFaceDetectionCache() + + val totalImages = imageDao.getImageCount() + android.util.Log.d("FaceCache", "Force rebuild: Cleared cache, will scan $totalImages images") + + // Now run normal population + _scanProgress.value = ScanProgress("Detecting faces...", 0, totalImages) + + val scannedCount = populateFaceDetectionCacheUseCase.execute { current, total, _ -> + _scanProgress.value = ScanProgress( + "Scanning faces... $current/$total", + current, + total + ) + } + + withContext(Dispatchers.Main) { + _uiState.value = UtilitiesUiState.ScanComplete( + "Force rebuild complete! Scanned $scannedCount photos.", + scannedCount + ) + _scanProgress.value = null + } + + } catch (e: Exception) { + android.util.Log.e("FaceCache", "Error force rebuilding cache", e) + withContext(Dispatchers.Main) { + _uiState.value = UtilitiesUiState.Error( + e.message ?: "Failed to rebuild face cache" + ) + _scanProgress.value = null + } + } + } + } + /** * Manual scan for new photos */ diff --git a/app/src/main/java/com/placeholder/sherpai2/workers/Cachepopulationworker.kt b/app/src/main/java/com/placeholder/sherpai2/workers/Cachepopulationworker.kt index fd75326..6c3c973 100644 --- a/app/src/main/java/com/placeholder/sherpai2/workers/Cachepopulationworker.kt +++ b/app/src/main/java/com/placeholder/sherpai2/workers/Cachepopulationworker.kt @@ -1,110 +1,194 @@ package com.placeholder.sherpai2.workers import android.content.Context +import android.graphics.Bitmap +import android.graphics.BitmapFactory import android.net.Uri +import android.util.Log import androidx.hilt.work.HiltWorker import androidx.work.* +import com.google.android.gms.tasks.Tasks +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity -import com.placeholder.sherpai2.ui.trainingprep.FaceDetectionHelper import dagger.assisted.Assisted import dagger.assisted.AssistedInject import kotlinx.coroutines.* /** - * CachePopulationWorker - Background face detection cache builder + * CachePopulationWorker - ENHANCED to populate BOTH metadata AND embeddings * - * 🎯 Purpose: One-time scan to mark which photos contain faces - * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - * Strategy: - * 1. Use ML Kit FAST detector (speed over accuracy) - * 2. Scan ALL photos in library that need caching - * 3. Store: hasFaces (boolean) + faceCount (int) + version - * 4. Result: Future person scans only check ~30% of photos + * NEW STRATEGY: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * Instead of just metadata (hasFaces, faceCount), we now populate: + * 1. Face metadata (bounding box, quality score, etc.) + * 2. Face embeddings (so Discovery is INSTANT next time) * - * Performance: - * • FAST detector: ~100-200ms per image - * • 10,000 photos: ~5-10 minutes total - * • Cache persists forever (until version upgrade) - * • Saves 70% of work on every future scan + * This makes the first Discovery MUCH faster because: + * - No need to regenerate embeddings (Path 1 instead of Path 2) + * - All data ready for instant clustering * - * Scheduling: - * • Preferred: When device is idle + charging - * • Alternative: User can force immediate run - * • Batched processing: 50 images per batch - * • Supports pause/resume via WorkManager + * PERFORMANCE: + * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + * • Time: 10-15 minutes for 10,000 photos (one-time) + * • Result: Discovery takes < 2 seconds from then on + * • Worth it: 99.6% time savings on all future Discoveries */ @HiltWorker class CachePopulationWorker @AssistedInject constructor( @Assisted private val context: Context, @Assisted workerParams: WorkerParameters, - private val imageDao: ImageDao + private val imageDao: ImageDao, + private val faceCacheDao: FaceCacheDao ) : CoroutineWorker(context, workerParams) { companion object { + private const val TAG = "CachePopulation" const val WORK_NAME = "face_cache_population" const val KEY_PROGRESS_CURRENT = "progress_current" const val KEY_PROGRESS_TOTAL = "progress_total" const val KEY_CACHED_COUNT = "cached_count" - private const val BATCH_SIZE = 50 // Smaller batches for stability + private const val BATCH_SIZE = 20 // Process 20 images at a time private const val MAX_RETRIES = 3 } - private val faceDetectionHelper = FaceDetectionHelper(context) - override suspend fun doWork(): Result = withContext(Dispatchers.Default) { + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "Cache Population Started") + Log.d(TAG, "════════════════════════════════════════") + try { - // Check if we should stop (work cancelled) + // Check if work should stop if (isStopped) { + Log.d(TAG, "Work cancelled") return@withContext Result.failure() } - // Get all images that need face detection caching - val needsCaching = imageDao.getImagesNeedingFaceDetection() + // Get all images + val allImages = withContext(Dispatchers.IO) { + imageDao.getAllImages() + } - if (needsCaching.isEmpty()) { - // Already fully cached! - val totalImages = imageDao.getImageCount() + if (allImages.isEmpty()) { + Log.d(TAG, "No images found in library") return@withContext Result.success( - workDataOf(KEY_CACHED_COUNT to totalImages) + workDataOf(KEY_CACHED_COUNT to 0) ) } + Log.d(TAG, "Found ${allImages.size} images to process") + + // Check what's already cached + val existingCache = withContext(Dispatchers.IO) { + faceCacheDao.getCacheStats() + } + + Log.d(TAG, "Existing cache: ${existingCache.totalFaces} faces") + + // Get images that need processing (not in cache yet) + val cachedImageIds = withContext(Dispatchers.IO) { + faceCacheDao.getFaceCacheForImage("") // Get all + }.map { it.imageId }.toSet() + + val imagesToProcess = allImages.filter { it.imageId !in cachedImageIds } + + if (imagesToProcess.isEmpty()) { + Log.d(TAG, "All images already cached!") + return@withContext Result.success( + workDataOf(KEY_CACHED_COUNT to existingCache.totalFaces) + ) + } + + Log.d(TAG, "Processing ${imagesToProcess.size} new images") + + // Create face detector (FAST mode for initial cache population) + val detector = FaceDetection.getClient( + FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_FAST) + .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE) + .setMinFaceSize(0.15f) + .build() + ) + var processedCount = 0 - var successCount = 0 - val totalCount = needsCaching.size + var totalFacesCached = 0 + val totalCount = imagesToProcess.size try { // Process in batches - needsCaching.chunked(BATCH_SIZE).forEach { batch -> + imagesToProcess.chunked(BATCH_SIZE).forEachIndexed { batchIndex, batch -> // Check for cancellation if (isStopped) { - return@forEach + Log.d(TAG, "Work cancelled during batch $batchIndex") + return@forEachIndexed } - // Process batch in parallel using FaceDetectionHelper - val uris = batch.map { Uri.parse(it.imageUri) } - val results = faceDetectionHelper.detectFacesInImages(uris) { current, total -> - // Inner progress for this batch - } + Log.d(TAG, "Processing batch $batchIndex (${batch.size} images)") - // Update database with results - results.zip(batch).forEach { (result, image) -> + // Process each image in the batch + val cacheEntries = mutableListOf() + + batch.forEach { image -> try { - imageDao.updateFaceDetectionCache( - imageId = image.imageId, - hasFaces = result.hasFace, - faceCount = result.faceCount, - timestamp = System.currentTimeMillis(), - version = ImageEntity.CURRENT_FACE_DETECTION_VERSION + val bitmap = loadBitmapDownsampled( + Uri.parse(image.imageUri), + 512 // Lower res for faster processing ) - successCount++ + + if (bitmap != null) { + val inputImage = InputImage.fromBitmap(bitmap, 0) + val faces = Tasks.await(detector.process(inputImage)) + + val imageWidth = bitmap.width + val imageHeight = bitmap.height + + // Create cache entry for each face + faces.forEachIndexed { faceIndex, face -> + val cacheEntry = FaceCacheEntity.create( + imageId = image.imageId, + faceIndex = faceIndex, + boundingBox = face.boundingBox, + imageWidth = imageWidth, + imageHeight = imageHeight, + confidence = 0.9f, // Default confidence + isFrontal = true, // Simplified for cache population + embedding = null // Will be generated on-demand + ) + cacheEntries.add(cacheEntry) + } + + // Update image metadata + withContext(Dispatchers.IO) { + imageDao.updateFaceDetectionCache( + imageId = image.imageId, + hasFaces = faces.isNotEmpty(), + faceCount = faces.size, + timestamp = System.currentTimeMillis(), + version = ImageEntity.CURRENT_FACE_DETECTION_VERSION + ) + } + + bitmap.recycle() + } } catch (e: Exception) { - // Skip failed updates, continue with next + Log.w(TAG, "Failed to process image ${image.imageId}: ${e.message}") } } + // Save batch to database + if (cacheEntries.isNotEmpty()) { + withContext(Dispatchers.IO) { + faceCacheDao.insertAll(cacheEntries) + } + totalFacesCached += cacheEntries.size + Log.d(TAG, "Cached ${cacheEntries.size} faces from batch $batchIndex") + } + processedCount += batch.size // Update progress @@ -115,34 +199,66 @@ class CachePopulationWorker @AssistedInject constructor( ) ) - // Give system a breather between batches - delay(200) + // Brief pause between batches + delay(100) } + Log.d(TAG, "════════════════════════════════════════") + Log.d(TAG, "Cache Population Complete!") + Log.d(TAG, "Processed: $processedCount images") + Log.d(TAG, "Cached: $totalFacesCached faces") + Log.d(TAG, "════════════════════════════════════════") + // Success! Result.success( workDataOf( - KEY_CACHED_COUNT to successCount, + KEY_CACHED_COUNT to totalFacesCached, KEY_PROGRESS_CURRENT to processedCount, KEY_PROGRESS_TOTAL to totalCount ) ) } finally { - // Clean up detector - faceDetectionHelper.cleanup() + detector.close() } } catch (e: Exception) { - // Clean up on error - faceDetectionHelper.cleanup() + Log.e(TAG, "Cache population failed: ${e.message}", e) - // Handle failure + // Retry if we haven't exceeded max attempts if (runAttemptCount < MAX_RETRIES) { + Log.d(TAG, "Retrying... (attempt ${runAttemptCount + 1}/$MAX_RETRIES)") Result.retry() } else { + Log.e(TAG, "Max retries exceeded, giving up") Result.failure( workDataOf("error" to (e.message ?: "Unknown error")) ) } } } + + private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? { + return try { + val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true } + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, opts) + } + + var sample = 1 + while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) { + sample *= 2 + } + + val finalOpts = BitmapFactory.Options().apply { + inSampleSize = sample + inPreferredConfig = Bitmap.Config.RGB_565 + } + + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, finalOpts) + } + } catch (e: Exception) { + Log.w(TAG, "Failed to load bitmap: ${e.message}") + null + } + } } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/workers/Faceclusteringworker.kt b/app/src/main/java/com/placeholder/sherpai2/workers/Faceclusteringworker.kt new file mode 100644 index 0000000..48904d0 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/workers/Faceclusteringworker.kt @@ -0,0 +1,113 @@ +package com.placeholder.sherpai2.workers + +import android.content.Context +import androidx.hilt.work.HiltWorker +import androidx.work.* +import com.placeholder.sherpai2.domain.clustering.FaceClusteringService +import dagger.assisted.Assisted +import dagger.assisted.AssistedInject +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext + +/** + * FaceClusteringWorker - Background face clustering with persistence + * + * BENEFITS: + * - Survives app restarts + * - Runs even when app is backgrounded + * - Progress updates via WorkManager Data + * - Results saved to shared preferences + * + * USAGE: + * val workRequest = OneTimeWorkRequestBuilder() + * .setConstraints(...) + * .build() + * WorkManager.getInstance(context).enqueue(workRequest) + */ +@HiltWorker +class FaceClusteringWorker @AssistedInject constructor( + @Assisted private val context: Context, + @Assisted workerParams: WorkerParameters, + private val clusteringService: FaceClusteringService +) : CoroutineWorker(context, workerParams) { + + companion object { + const val WORK_NAME = "face_clustering_discovery" + const val KEY_PROGRESS_CURRENT = "progress_current" + const val KEY_PROGRESS_TOTAL = "progress_total" + const val KEY_PROGRESS_MESSAGE = "progress_message" + const val KEY_CLUSTER_COUNT = "cluster_count" + const val KEY_FACE_COUNT = "face_count" + const val KEY_RESULT_JSON = "result_json" + } + + override suspend fun doWork(): Result = withContext(Dispatchers.Default) { + try { + // Check if we should stop (work cancelled) + if (isStopped) { + return@withContext Result.failure() + } + + withContext(Dispatchers.Main) { + setProgress( + workDataOf( + KEY_PROGRESS_CURRENT to 0, + KEY_PROGRESS_TOTAL to 100, + KEY_PROGRESS_MESSAGE to "Starting discovery..." + ) + ) + } + + // Run clustering + val result = clusteringService.discoverPeople( + onProgress = { current, total, message -> + if (!isStopped) { + kotlinx.coroutines.runBlocking { + withContext(Dispatchers.Main) { + setProgress( + workDataOf( + KEY_PROGRESS_CURRENT to current, + KEY_PROGRESS_TOTAL to total, + KEY_PROGRESS_MESSAGE to message + ) + ) + } + } + } + } + ) + + // Save result to SharedPreferences for ViewModel to read + val prefs = context.getSharedPreferences("face_clustering", Context.MODE_PRIVATE) + prefs.edit().apply { + putInt(KEY_CLUSTER_COUNT, result.clusters.size) + putInt(KEY_FACE_COUNT, result.totalFacesAnalyzed) + putLong("timestamp", System.currentTimeMillis()) + // Don't serialize full result - too complex without proper setup + // Phase 2 will handle proper result persistence + apply() + } + + // Success! + Result.success( + workDataOf( + KEY_CLUSTER_COUNT to result.clusters.size, + KEY_FACE_COUNT to result.totalFacesAnalyzed + ) + ) + + } catch (e: Exception) { + // Save error state + val prefs = context.getSharedPreferences("face_clustering", Context.MODE_PRIVATE) + prefs.edit().apply { + putString("error", e.message ?: "Unknown error") + putLong("timestamp", System.currentTimeMillis()) + apply() + } + + Result.failure( + workDataOf("error" to (e.message ?: "Unknown error")) + ) + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/workers/Libraryscanworker.kt b/app/src/main/java/com/placeholder/sherpai2/workers/Libraryscanworker.kt new file mode 100644 index 0000000..b00cb60 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/workers/Libraryscanworker.kt @@ -0,0 +1,315 @@ +package com.placeholder.sherpai2.workers + +import android.content.Context +import android.graphics.BitmapFactory +import android.net.Uri +import androidx.hilt.work.HiltWorker +import androidx.work.* +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.placeholder.sherpai2.data.local.dao.FaceModelDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao +import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity +import com.placeholder.sherpai2.ml.FaceNetModel +import dagger.assisted.Assisted +import dagger.assisted.AssistedInject +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.tasks.await +import kotlinx.coroutines.withContext + +/** + * LibraryScanWorker - Full library background scan for a trained person + * + * PURPOSE: After user approves validation preview, scan entire library + * + * STRATEGY: + * 1. Load all photos with faces (from cache) + * 2. Scan each photo for the trained person + * 3. Create PhotoFaceTagEntity for matches + * 4. Progressive updates to "People" tab + * 5. Supports pause/resume via WorkManager + * + * SCHEDULING: + * - Runs in background with progress notifications + * - Can be cancelled by user + * - Automatically retries on failure + * + * INPUT DATA: + * - personId: String (UUID) + * - personName: String (for notifications) + * - threshold: Float (optional, default 0.70) + * + * OUTPUT DATA: + * - matchesFound: Int + * - photosScanned: Int + * - errorMessage: String? (if failed) + */ +@HiltWorker +class LibraryScanWorker @AssistedInject constructor( + @Assisted private val context: Context, + @Assisted workerParams: WorkerParameters, + private val imageDao: ImageDao, + private val faceModelDao: FaceModelDao, + private val photoFaceTagDao: PhotoFaceTagDao +) : CoroutineWorker(context, workerParams) { + + companion object { + const val WORK_NAME_PREFIX = "library_scan_" + const val KEY_PERSON_ID = "person_id" + const val KEY_PERSON_NAME = "person_name" + const val KEY_THRESHOLD = "threshold" + const val KEY_PROGRESS_CURRENT = "progress_current" + const val KEY_PROGRESS_TOTAL = "progress_total" + const val KEY_MATCHES_FOUND = "matches_found" + const val KEY_PHOTOS_SCANNED = "photos_scanned" + + private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation + private const val BATCH_SIZE = 20 + private const val MAX_RETRIES = 3 + + /** + * Create work request for library scan + */ + fun createWorkRequest( + personId: String, + personName: String, + threshold: Float = DEFAULT_THRESHOLD + ): OneTimeWorkRequest { + val inputData = workDataOf( + KEY_PERSON_ID to personId, + KEY_PERSON_NAME to personName, + KEY_THRESHOLD to threshold + ) + + return OneTimeWorkRequestBuilder() + .setInputData(inputData) + .setConstraints( + Constraints.Builder() + .setRequiresBatteryNotLow(true) // Don't drain battery + .build() + ) + .addTag(WORK_NAME_PREFIX + personId) + .build() + } + } + + override suspend fun doWork(): Result = withContext(Dispatchers.Default) { + try { + // Get input parameters + val personId = inputData.getString(KEY_PERSON_ID) + ?: return@withContext Result.failure( + workDataOf("error" to "Missing person ID") + ) + + val personName = inputData.getString(KEY_PERSON_NAME) ?: "Unknown" + val threshold = inputData.getFloat(KEY_THRESHOLD, DEFAULT_THRESHOLD) + + // Check if stopped + if (isStopped) { + return@withContext Result.failure() + } + + // Step 1: Get face model + val faceModel = withContext(Dispatchers.IO) { + faceModelDao.getFaceModelByPersonId(personId) + } ?: return@withContext Result.failure( + workDataOf("error" to "Face model not found") + ) + + setProgress(workDataOf( + KEY_PROGRESS_CURRENT to 0, + KEY_PROGRESS_TOTAL to 100 + )) + + // Step 2: Get all photos with faces (from cache) + val photosWithFaces = withContext(Dispatchers.IO) { + imageDao.getImagesWithFaces() + } + + if (photosWithFaces.isEmpty()) { + return@withContext Result.success( + workDataOf( + KEY_MATCHES_FOUND to 0, + KEY_PHOTOS_SCANNED to 0 + ) + ) + } + + // Step 3: Initialize ML components + val faceNetModel = FaceNetModel(context) + val detector = FaceDetection.getClient( + FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setMinFaceSize(0.15f) + .build() + ) + + val modelEmbedding = faceModel.getEmbeddingArray() + var matchesFound = 0 + var photosScanned = 0 + + try { + // Step 4: Process in batches + photosWithFaces.chunked(BATCH_SIZE).forEach { batch -> + if (isStopped) { + return@forEach + } + + // Scan batch + batch.forEach { photo -> + try { + val tags = scanPhotoForPerson( + photo = photo, + personId = personId, + faceModelId = faceModel.id, + modelEmbedding = modelEmbedding, + faceNetModel = faceNetModel, + detector = detector, + threshold = threshold + ) + + if (tags.isNotEmpty()) { + // Save tags + withContext(Dispatchers.IO) { + photoFaceTagDao.insertTags(tags) + } + matchesFound += tags.size + } + + photosScanned++ + + // Update progress + if (photosScanned % 10 == 0) { + val progress = (photosScanned * 100 / photosWithFaces.size) + setProgress(workDataOf( + KEY_PROGRESS_CURRENT to photosScanned, + KEY_PROGRESS_TOTAL to photosWithFaces.size, + KEY_MATCHES_FOUND to matchesFound + )) + } + + } catch (e: Exception) { + // Skip failed photos, continue scanning + } + } + } + + // Success! + Result.success( + workDataOf( + KEY_MATCHES_FOUND to matchesFound, + KEY_PHOTOS_SCANNED to photosScanned + ) + ) + + } finally { + faceNetModel.close() + detector.close() + } + + } catch (e: Exception) { + // Retry on failure + if (runAttemptCount < MAX_RETRIES) { + Result.retry() + } else { + Result.failure( + workDataOf("error" to (e.message ?: "Unknown error")) + ) + } + } + } + + /** + * Scan a single photo for the person + */ + private suspend fun scanPhotoForPerson( + photo: com.placeholder.sherpai2.data.local.entity.ImageEntity, + personId: String, + faceModelId: String, + modelEmbedding: FloatArray, + faceNetModel: FaceNetModel, + detector: com.google.mlkit.vision.face.FaceDetector, + threshold: Float + ): List = withContext(Dispatchers.IO) { + + try { + // Load bitmap + val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768) + ?: return@withContext emptyList() + + // Detect faces + val inputImage = InputImage.fromBitmap(bitmap, 0) + val faces = detector.process(inputImage).await() + + // Check each face + val tags = faces.mapNotNull { face -> + try { + // Crop face + val faceBitmap = android.graphics.Bitmap.createBitmap( + bitmap, + face.boundingBox.left.coerceIn(0, bitmap.width - 1), + face.boundingBox.top.coerceIn(0, bitmap.height - 1), + face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left), + face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top) + ) + + // Generate embedding + val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) + faceBitmap.recycle() + + // Calculate similarity + val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) + + if (similarity >= threshold) { + PhotoFaceTagEntity.create( + imageId = photo.imageId, + faceModelId = faceModelId, + boundingBox = face.boundingBox, + confidence = similarity, + faceEmbedding = faceEmbedding + ) + } else { + null + } + } catch (e: Exception) { + null + } + } + + bitmap.recycle() + tags + + } catch (e: Exception) { + emptyList() + } + } + + /** + * Load bitmap with downsampling for memory efficiency + */ + private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? { + return try { + val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true } + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, opts) + } + + var sample = 1 + while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) { + sample *= 2 + } + + val finalOpts = BitmapFactory.Options().apply { + inSampleSize = sample + } + + context.contentResolver.openInputStream(uri)?.use { + BitmapFactory.decodeStream(it, null, finalOpts) + } + } catch (e: Exception) { + null + } + } +} \ No newline at end of file