diff --git a/app/build.gradle.kts b/app/build.gradle.kts index d304993..6e693ef 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -77,5 +77,12 @@ dependencies { implementation(libs.mlkit.face.detection) implementation(libs.kotlinx.coroutines.play.services) + //Face Rec + implementation(libs.tensorflow.lite) + implementation(libs.tensorflow.lite.support) + // Optional: GPU acceleration + implementation(libs.tensorflow.lite.gpu) + // Gson for storing FloatArrays in Room + implementation(libs.gson) } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt index c42000a..6e42735 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt @@ -2,46 +2,50 @@ package com.placeholder.sherpai2.data.local import androidx.room.Database import androidx.room.RoomDatabase -import com.placeholder.sherpai2.data.local.dao.EventDao -import com.placeholder.sherpai2.data.local.dao.ImageAggregateDao -import com.placeholder.sherpai2.data.local.dao.ImageDao -import com.placeholder.sherpai2.data.local.dao.ImageEventDao -import com.placeholder.sherpai2.data.local.dao.ImagePersonDao -import com.placeholder.sherpai2.data.local.dao.ImageTagDao -import com.placeholder.sherpai2.data.local.dao.PersonDao -import com.placeholder.sherpai2.data.local.dao.TagDao -import com.placeholder.sherpai2.data.local.entity.EventEntity -import com.placeholder.sherpai2.data.local.entity.ImageEntity -import com.placeholder.sherpai2.data.local.entity.ImageEventEntity -import com.placeholder.sherpai2.data.local.entity.ImagePersonEntity -import com.placeholder.sherpai2.data.local.entity.ImageTagEntity -import com.placeholder.sherpai2.data.local.entity.PersonEntity -import com.placeholder.sherpai2.data.local.entity.TagEntity +import com.placeholder.sherpai2.data.local.dao.* +import com.placeholder.sherpai2.data.local.entity.* +/** + * AppDatabase - Complete database for SherpAI2 + * + * ENTITIES: + * - YOUR EXISTING: Image, Tag, Event, junction tables + * - NEW: PersonEntity (people in your app) + * - NEW: FaceModelEntity (face embeddings, links to PersonEntity) + * - NEW: PhotoFaceTagEntity (face detections, links to ImageEntity + FaceModelEntity) + */ @Database( entities = [ + // ===== YOUR EXISTING ENTITIES ===== ImageEntity::class, TagEntity::class, - PersonEntity::class, EventEntity::class, ImageTagEntity::class, ImagePersonEntity::class, - ImageEventEntity::class - ], - version = 1, - exportSchema = true -) + ImageEventEntity::class, + // ===== NEW ENTITIES ===== + PersonEntity::class, // NEW: People + FaceModelEntity::class, // NEW: Face embeddings + PhotoFaceTagEntity::class // NEW: Face tags + ], + version = 3, + exportSchema = false +) +// No TypeConverters needed - embeddings stored as strings abstract class AppDatabase : RoomDatabase() { + // ===== YOUR EXISTING DAOs ===== abstract fun imageDao(): ImageDao abstract fun tagDao(): TagDao - abstract fun personDao(): PersonDao abstract fun eventDao(): EventDao - abstract fun imageTagDao(): ImageTagDao abstract fun imagePersonDao(): ImagePersonDao abstract fun imageEventDao(): ImageEventDao - abstract fun imageAggregateDao(): ImageAggregateDao + + // ===== NEW DAOs ===== + abstract fun personDao(): PersonDao // NEW: Manage people + abstract fun faceModelDao(): FaceModelDao // NEW: Manage face embeddings + abstract fun photoFaceTagDao(): PhotoFaceTagDao // NEW: Manage face tags } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/FaceModelDao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/FaceModelDao.kt new file mode 100644 index 0000000..45db325 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/FaceModelDao.kt @@ -0,0 +1,44 @@ +package com.placeholder.sherpai2.data.local.dao + +import androidx.room.* +import kotlinx.coroutines.flow.Flow +import com.placeholder.sherpai2.data.local.entity.FaceModelEntity +/** + * FaceModelDao - Manages face recognition models + * + * PRIMARY KEY TYPE: String (UUID) + * FOREIGN KEY: personId (String) + */ +@Dao +interface FaceModelDao { + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertFaceModel(faceModel: FaceModelEntity): Long // Row ID + + @Update + suspend fun updateFaceModel(faceModel: FaceModelEntity) + + @Query("UPDATE face_models SET lastUsed = :timestamp WHERE id = :faceModelId") + suspend fun updateLastUsed(faceModelId: String, timestamp: Long) + + @Query("SELECT * FROM face_models WHERE id = :faceModelId") + suspend fun getFaceModelById(faceModelId: String): FaceModelEntity? + + @Query("SELECT * FROM face_models WHERE personId = :personId AND isActive = 1") + suspend fun getFaceModelByPersonId(personId: String): FaceModelEntity? + + @Query("SELECT * FROM face_models WHERE isActive = 1 ORDER BY lastUsed DESC") + suspend fun getAllActiveFaceModels(): List + + @Query("SELECT * FROM face_models WHERE isActive = 1 ORDER BY lastUsed DESC") + fun getAllActiveFaceModelsFlow(): Flow> + + @Query("DELETE FROM face_models WHERE id = :faceModelId") + suspend fun deleteFaceModelById(faceModelId: String) + + @Query("UPDATE face_models SET isActive = 0 WHERE id = :faceModelId") + suspend fun deactivateFaceModel(faceModelId: String) + + @Query("SELECT COUNT(*) FROM face_models WHERE isActive = 1") + suspend fun getActiveFaceModelCount(): Int +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt index 6589f0f..21f6460 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/ImageDao.kt @@ -65,4 +65,11 @@ interface ImageDao { @Insert(onConflict = OnConflictStrategy.IGNORE) suspend fun insert(image: ImageEntity) -} + + /** + * Get images by list of IDs. + * FIXED: Changed from List to List to match ImageEntity.imageId type + */ + @Query("SELECT * FROM images WHERE imageId IN (:imageIds)") + suspend fun getImagesByIds(imageIds: List): List +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt index 4f56365..87c8c8a 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/PersonDao.kt @@ -1,24 +1,53 @@ package com.placeholder.sherpai2.data.local.dao -import androidx.room.Dao -import androidx.room.Insert -import androidx.room.OnConflictStrategy -import androidx.room.Query +import androidx.room.* import com.placeholder.sherpai2.data.local.entity.PersonEntity +import kotlinx.coroutines.flow.Flow +/** + * PersonDao - Data access for PersonEntity + * + * PRIMARY KEY TYPE: String (UUID) + */ @Dao interface PersonDao { @Insert(onConflict = OnConflictStrategy.REPLACE) - suspend fun insert(person: PersonEntity) + suspend fun insert(person: PersonEntity): Long // Room still returns row ID as Long - @Query("SELECT * FROM persons WHERE personId = :personId") - suspend fun getById(personId: String): PersonEntity? + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertAll(persons: List) - @Query(""" - SELECT * FROM persons - WHERE isHidden = 0 - ORDER BY displayName - """) - suspend fun getVisiblePeople(): List -} + @Update + suspend fun update(person: PersonEntity) + + @Query("UPDATE persons SET updatedAt = :timestamp WHERE id = :personId") + suspend fun updateTimestamp(personId: String, timestamp: Long = System.currentTimeMillis()) + + @Delete + suspend fun delete(person: PersonEntity) + + @Query("DELETE FROM persons WHERE id = :personId") + suspend fun deleteById(personId: String) + + @Query("SELECT * FROM persons WHERE id = :personId") + suspend fun getPersonById(personId: String): PersonEntity? + + @Query("SELECT * FROM persons WHERE id IN (:personIds)") + suspend fun getPersonsByIds(personIds: List): List + + @Query("SELECT * FROM persons ORDER BY name ASC") + suspend fun getAllPersons(): List + + @Query("SELECT * FROM persons ORDER BY name ASC") + fun getAllPersonsFlow(): Flow> + + @Query("SELECT * FROM persons WHERE name LIKE '%' || :query || '%' ORDER BY name ASC") + suspend fun searchByName(query: String): List + + @Query("SELECT COUNT(*) FROM persons") + suspend fun getPersonCount(): Int + + @Query("SELECT EXISTS(SELECT 1 FROM persons WHERE id = :personId)") + suspend fun personExists(personId: String): Boolean +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt new file mode 100644 index 0000000..d873d68 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt @@ -0,0 +1,91 @@ +package com.placeholder.sherpai2.data.local.dao + +import androidx.room.* +import kotlinx.coroutines.flow.Flow +import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity + +/** + * PhotoFaceTagDao - Manages face tags in photos + * + * PRIMARY KEY TYPE: String (UUID) + * FOREIGN KEYS: imageId (String), faceModelId (String) + */ +@Dao +interface PhotoFaceTagDao { + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertTag(tag: PhotoFaceTagEntity): Long // Row ID + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertTags(tags: List) + + @Update + suspend fun updateTag(tag: PhotoFaceTagEntity) + + @Query("UPDATE photo_face_tags SET verifiedByUser = 1, verifiedAt = :timestamp WHERE id = :tagId") + suspend fun markTagAsVerified(tagId: String, timestamp: Long = System.currentTimeMillis()) + + // ===== QUERY BY IMAGE ===== + + @Query("SELECT * FROM photo_face_tags WHERE imageId = :imageId") + suspend fun getTagsForImage(imageId: String): List + + @Query("SELECT COUNT(*) FROM photo_face_tags WHERE imageId = :imageId") + suspend fun getFaceCountForImage(imageId: String): Int + + @Query("SELECT EXISTS(SELECT 1 FROM photo_face_tags WHERE imageId = :imageId AND faceModelId = :faceModelId)") + suspend fun imageHasPerson(imageId: String, faceModelId: String): Boolean + + // ===== QUERY BY FACE MODEL ===== + + @Query("SELECT DISTINCT imageId FROM photo_face_tags WHERE faceModelId = :faceModelId ORDER BY detectedAt DESC") + suspend fun getImageIdsForFaceModel(faceModelId: String): List + + @Query("SELECT DISTINCT imageId FROM photo_face_tags WHERE faceModelId = :faceModelId ORDER BY detectedAt DESC") + fun getImageIdsForFaceModelFlow(faceModelId: String): Flow> + + @Query("SELECT faceModelId, COUNT(DISTINCT imageId) as photoCount FROM photo_face_tags GROUP BY faceModelId") + suspend fun getPhotoCountPerFaceModel(): List + + @Query("SELECT * FROM photo_face_tags WHERE faceModelId = :faceModelId ORDER BY detectedAt DESC") + suspend fun getAllTagsForFaceModel(faceModelId: String): List + + // ===== DELETE ===== + + @Delete + suspend fun deleteTag(tag: PhotoFaceTagEntity) + + @Query("DELETE FROM photo_face_tags WHERE id = :tagId") + suspend fun deleteTagById(tagId: String) + + @Query("DELETE FROM photo_face_tags WHERE faceModelId = :faceModelId") + suspend fun deleteTagsForFaceModel(faceModelId: String) + + @Query("DELETE FROM photo_face_tags WHERE imageId = :imageId") + suspend fun deleteTagsForImage(imageId: String) + + // ===== STATISTICS ===== + + @Query("SELECT * FROM photo_face_tags WHERE confidence < :threshold ORDER BY confidence ASC") + suspend fun getLowConfidenceTags(threshold: Float = 0.7f): List + + @Query("SELECT * FROM photo_face_tags WHERE verifiedByUser = 0 ORDER BY detectedAt DESC") + suspend fun getUnverifiedTags(): List + + @Query("SELECT COUNT(*) FROM photo_face_tags WHERE verifiedByUser = 0") + suspend fun getUnverifiedTagCount(): Int + + @Query("SELECT AVG(confidence) FROM photo_face_tags WHERE faceModelId = :faceModelId") + suspend fun getAverageConfidenceForFaceModel(faceModelId: String): Float? + + @Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit") + suspend fun getRecentlyDetectedFaces(limit: Int = 20): List +} + +/** + * Simple data class for photo counts + */ +data class FaceModelPhotoCount( + val faceModelId: String, + val photoCount: Int +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt new file mode 100644 index 0000000..2f99432 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facerecognitionentities.kt @@ -0,0 +1,155 @@ +package com.placeholder.sherpai2.data.local.entity + +import androidx.room.Entity +import androidx.room.ForeignKey +import androidx.room.Index +import androidx.room.PrimaryKey +import java.util.UUID + +/** + * PersonEntity - Represents a person in the face recognition system + * + * TABLE: persons + * PRIMARY KEY: id (String) + */ +@Entity( + tableName = "persons", + indices = [ + Index(value = ["name"]) + ] +) +data class PersonEntity( + @PrimaryKey + val id: String = UUID.randomUUID().toString(), + + val name: String, + val createdAt: Long = System.currentTimeMillis(), + val updatedAt: Long = System.currentTimeMillis() +) + +/** + * FaceModelEntity - Stores face recognition model (embedding) for a person + * + * TABLE: face_models + * FOREIGN KEY: personId → persons.id + */ +@Entity( + tableName = "face_models", + foreignKeys = [ + ForeignKey( + entity = PersonEntity::class, + parentColumns = ["id"], + childColumns = ["personId"], + onDelete = ForeignKey.CASCADE + ) + ], + indices = [ + Index(value = ["personId"], unique = true) + ] +) +data class FaceModelEntity( + @PrimaryKey + val id: String = UUID.randomUUID().toString(), + + val personId: String, + val embedding: String, // Serialized FloatArray + val trainingImageCount: Int, + val averageConfidence: Float, + val createdAt: Long = System.currentTimeMillis(), + val updatedAt: Long = System.currentTimeMillis(), + val lastUsed: Long? = null, + val isActive: Boolean = true +) { + companion object { + fun create( + personId: String, + embeddingArray: FloatArray, + trainingImageCount: Int, + averageConfidence: Float + ): FaceModelEntity { + return FaceModelEntity( + personId = personId, + embedding = embeddingArray.joinToString(","), + trainingImageCount = trainingImageCount, + averageConfidence = averageConfidence + ) + } + } + + fun getEmbeddingArray(): FloatArray { + return embedding.split(",").map { it.toFloat() }.toFloatArray() + } +} + +/** + * PhotoFaceTagEntity - Links detected faces in photos to person models + * + * TABLE: photo_face_tags + * FOREIGN KEYS: + * - imageId → images.imageId (String) + * - faceModelId → face_models.id (String) + */ +@Entity( + tableName = "photo_face_tags", + foreignKeys = [ + ForeignKey( + entity = ImageEntity::class, + parentColumns = ["imageId"], + childColumns = ["imageId"], + onDelete = ForeignKey.CASCADE + ), + ForeignKey( + entity = FaceModelEntity::class, + parentColumns = ["id"], + childColumns = ["faceModelId"], + onDelete = ForeignKey.CASCADE + ) + ], + indices = [ + Index(value = ["imageId"]), + Index(value = ["faceModelId"]), + Index(value = ["imageId", "faceModelId"]) + ] +) +data class PhotoFaceTagEntity( + @PrimaryKey + val id: String = UUID.randomUUID().toString(), + + val imageId: String, // String to match ImageEntity.imageId + val faceModelId: String, + + val boundingBox: String, // "left,top,right,bottom" + val confidence: Float, + val embedding: String, // Serialized FloatArray + + val detectedAt: Long = System.currentTimeMillis(), + val verifiedByUser: Boolean = false, + val verifiedAt: Long? = null +) { + companion object { + fun create( + imageId: String, + faceModelId: String, + boundingBox: android.graphics.Rect, + confidence: Float, + faceEmbedding: FloatArray + ): PhotoFaceTagEntity { + return PhotoFaceTagEntity( + imageId = imageId, + faceModelId = faceModelId, + boundingBox = "${boundingBox.left},${boundingBox.top},${boundingBox.right},${boundingBox.bottom}", + confidence = confidence, + embedding = faceEmbedding.joinToString(",") + ) + } + } + + fun getBoundingBox(): android.graphics.Rect { + val parts = boundingBox.split(",").map { it.toInt() } + return android.graphics.Rect(parts[0], parts[1], parts[2], parts[3]) + } + + fun getEmbeddingArray(): FloatArray { + return embedding.split(",").map { it.toFloat() }.toFloatArray() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/ImagePersonEntity.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/ImagePersonEntity.kt index e327675..f1b3142 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/ImagePersonEntity.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/ImagePersonEntity.kt @@ -16,7 +16,7 @@ import androidx.room.Index ), ForeignKey( entity = PersonEntity::class, - parentColumns = ["personId"], + parentColumns = ["id"], childColumns = ["personId"], onDelete = ForeignKey.CASCADE ) diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/PersonEntity b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/PersonEntity new file mode 100644 index 0000000..dca01dd --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/PersonEntity @@ -0,0 +1,49 @@ +package com.placeholder.sherpai2.data.local.entity + +import androidx.room.Entity +import androidx.room.PrimaryKey + +/** + * PersonEntity - Represents a person in your app + * + * This is a SIMPLE person entity for your existing database. + * Face embeddings are stored separately in FaceModelEntity. + * + * ARCHITECTURE: + * - PersonEntity = Human data (name, birthday, etc.) + * - FaceModelEntity = AI data (face embeddings) - links to this via personId + * + * You can add more fields as needed: + * - birthday: Long? + * - phoneNumber: String? + * - email: String? + * - notes: String? + * - etc. + */ +@Entity(tableName = "persons") +data class PersonEntity( + @PrimaryKey(autoGenerate = true) + val id: Long = 0, + + /** + * Person's name + */ + val name: String, + + /** + * When this person was added + */ + val createdAt: Long = System.currentTimeMillis(), + + /** + * Last time this person's data was updated + */ + val updatedAt: Long = System.currentTimeMillis() + + // ADD MORE FIELDS AS NEEDED: + // val birthday: Long? = null, + // val phoneNumber: String? = null, + // val email: String? = null, + // val profilePhotoUri: String? = null, + // val notes: String? = null +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/PersonsEntity.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/PersonsEntity.kt deleted file mode 100644 index 1360a3f..0000000 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/PersonsEntity.kt +++ /dev/null @@ -1,30 +0,0 @@ -package com.placeholder.sherpai2.data.local.entity - -import androidx.room.Entity -import androidx.room.PrimaryKey - -/** - * Represents a known person. - * - * People are separate from generic tags because: - * - face embeddings - * - privacy rules - * - identity merging - */ -@Entity(tableName = "persons") -data class PersonEntity( - - @PrimaryKey - val personId: String, - - val displayName: String, - - /** - * Reference to face embedding storage (ML layer). - */ - val faceEmbeddingId: String?, - - val isHidden: Boolean, - - val createdAt: Long -) diff --git a/app/src/main/java/com/placeholder/sherpai2/data/repository/Facerecognitionrepository.kt b/app/src/main/java/com/placeholder/sherpai2/data/repository/Facerecognitionrepository.kt new file mode 100644 index 0000000..a50ccf0 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/data/repository/Facerecognitionrepository.kt @@ -0,0 +1,357 @@ +package com.placeholder.sherpai2.data.repository + +import android.content.Context +import android.graphics.Bitmap +import com.placeholder.sherpai2.data.local.dao.FaceModelDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.dao.PersonDao +import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao +import com.placeholder.sherpai2.data.local.entity.* +import com.placeholder.sherpai2.ml.FaceNetModel +import com.placeholder.sherpai2.ui.trainingprep.TrainingSanityChecker +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.withContext +import javax.inject.Inject +import javax.inject.Singleton + +/** + * FaceRecognitionRepository - Complete face recognition system + * + * USES STRING IDs TO MATCH YOUR SCHEMA: + * - PersonEntity.id: String (UUID) + * - ImageEntity.imageId: String + * - FaceModelEntity.id: String (UUID) + * - PhotoFaceTagEntity.id: String (UUID) + */ +@Singleton +class FaceRecognitionRepository @Inject constructor( + private val context: Context, + private val personDao: PersonDao, + private val imageDao: ImageDao, + private val faceModelDao: FaceModelDao, + private val photoFaceTagDao: PhotoFaceTagDao +) { + + private val faceNetModel by lazy { FaceNetModel(context) } + + // ====================== + // TRAINING OPERATIONS + // ====================== + + /** + * Create a new person with face model in one operation. + * + * @return PersonId (String UUID) + */ + suspend fun createPersonWithFaceModel( + personName: String, + validImages: List, + onProgress: (Int, Int) -> Unit = { _, _ -> } + ): String = withContext(Dispatchers.IO) { + + // Create PersonEntity with UUID + val person = PersonEntity(name = personName) + personDao.insert(person) + + // Train face model + trainPerson( + personId = person.id, + validImages = validImages, + onProgress = onProgress + ) + + person.id + } + + /** + * Train a face recognition model for an existing person. + * + * @param personId String UUID + * @return Face model ID (String UUID) + */ + suspend fun trainPerson( + personId: String, + validImages: List, + onProgress: (Int, Int) -> Unit = { _, _ -> } + ): String = withContext(Dispatchers.Default) { + + val person = personDao.getPersonById(personId) + ?: throw IllegalArgumentException("Person with ID $personId not found") + + val embeddings = faceNetModel.generateEmbeddingsBatch( + faceBitmaps = validImages.map { it.croppedFaceBitmap }, + onProgress = onProgress + ) + + val personEmbedding = faceNetModel.createPersonModel(embeddings) + + val confidences = embeddings.map { embedding -> + faceNetModel.calculateSimilarity(personEmbedding, embedding) + } + val avgConfidence = confidences.average().toFloat() + + val faceModel = FaceModelEntity.create( + personId = personId, + embeddingArray = personEmbedding, + trainingImageCount = validImages.size, + averageConfidence = avgConfidence + ) + + faceModelDao.insertFaceModel(faceModel) + faceModel.id + } + + /** + * Retrain face model with additional images. + */ + suspend fun retrainFaceModel( + faceModelId: String, + newFaceImages: List + ) = withContext(Dispatchers.Default) { + + val faceModel = faceModelDao.getFaceModelById(faceModelId) + ?: throw IllegalArgumentException("Face model $faceModelId not found") + + val existingEmbedding = faceModel.getEmbeddingArray() + val newEmbeddings = faceNetModel.generateEmbeddingsBatch(newFaceImages) + val allEmbeddings = listOf(existingEmbedding) + newEmbeddings + val updatedEmbedding = faceNetModel.createPersonModel(allEmbeddings) + + val confidences = allEmbeddings.map { embedding -> + faceNetModel.calculateSimilarity(updatedEmbedding, embedding) + } + val avgConfidence = confidences.average().toFloat() + + faceModelDao.updateFaceModel( + FaceModelEntity.create( + personId = faceModel.personId, + embeddingArray = updatedEmbedding, + trainingImageCount = faceModel.trainingImageCount + newFaceImages.size, + averageConfidence = avgConfidence + ).copy( + id = faceModelId, + createdAt = faceModel.createdAt, + updatedAt = System.currentTimeMillis() + ) + ) + } + + // ====================== + // SCANNING / RECOGNITION + // ====================== + + /** + * Scan an image for faces and tag recognized persons. + * + * @param imageId String (from ImageEntity.imageId) + */ + suspend fun scanImage( + imageId: String, + detectedFaces: List, + threshold: Float = FaceNetModel.SIMILARITY_THRESHOLD_HIGH + ): List = withContext(Dispatchers.Default) { + + val faceModels = faceModelDao.getAllActiveFaceModels() + + if (faceModels.isEmpty()) { + return@withContext emptyList() + } + + val tags = mutableListOf() + + for (detectedFace in detectedFaces) { + val faceEmbedding = faceNetModel.generateEmbedding(detectedFace.croppedBitmap) + + var bestMatch: Pair? = null + var highestSimilarity = threshold + + for (faceModel in faceModels) { + val modelEmbedding = faceModel.getEmbeddingArray() + val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) + + if (similarity > highestSimilarity) { + highestSimilarity = similarity + bestMatch = Pair(faceModel.id, similarity) + } + } + + if (bestMatch != null) { + val (faceModelId, confidence) = bestMatch + + val tag = PhotoFaceTagEntity.create( + imageId = imageId, + faceModelId = faceModelId, + boundingBox = detectedFace.boundingBox, + confidence = confidence, + faceEmbedding = faceEmbedding + ) + + tags.add(tag) + faceModelDao.updateLastUsed(faceModelId, System.currentTimeMillis()) + } + } + + if (tags.isNotEmpty()) { + photoFaceTagDao.insertTags(tags) + } + + tags + } + + /** + * Recognize a single face bitmap (without saving). + */ + suspend fun recognizeFace( + faceBitmap: Bitmap, + threshold: Float = FaceNetModel.SIMILARITY_THRESHOLD_HIGH + ): Pair? = withContext(Dispatchers.Default) { + + val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) + val faceModels = faceModelDao.getAllActiveFaceModels() + val modelEmbeddings = faceModels.map { it.id to it.getEmbeddingArray() } + + faceNetModel.findBestMatch(faceEmbedding, modelEmbeddings, threshold) + } + + // ====================== + // SEARCH / QUERY + // ====================== + + /** + * Get all images containing a specific person. + * + * @param personId String UUID + */ + suspend fun getImagesForPerson(personId: String): List = withContext(Dispatchers.IO) { + + val faceModel = faceModelDao.getFaceModelByPersonId(personId) + ?: return@withContext emptyList() + + val imageIds = photoFaceTagDao.getImageIdsForFaceModel(faceModel.id) + imageDao.getImagesByIds(imageIds) + } + + /** + * Get images for person as Flow (reactive). + */ + fun getImagesForPersonFlow(personId: String): Flow> { + return photoFaceTagDao.getImageIdsForFaceModelFlow(personId) + .map { imageIds -> + imageDao.getImagesByIds(imageIds) + } + } + + /** + * Get all persons with face models. + */ + suspend fun getPersonsWithFaceModels(): List = withContext(Dispatchers.IO) { + val faceModels = faceModelDao.getAllActiveFaceModels() + val personIds = faceModels.map { it.personId } + personDao.getPersonsByIds(personIds) + } + + /** + * Get face detection stats for a person. + */ + suspend fun getPersonFaceStats(personId: String): PersonFaceStats? = withContext(Dispatchers.IO) { + + val person = personDao.getPersonById(personId) ?: return@withContext null + val faceModel = faceModelDao.getFaceModelByPersonId(personId) ?: return@withContext null + + val imageIds = photoFaceTagDao.getImageIdsForFaceModel(faceModel.id) + val allTags = photoFaceTagDao.getAllTagsForFaceModel(faceModel.id) + + val avgConfidence = if (allTags.isNotEmpty()) { + allTags.map { it.confidence }.average().toFloat() + } else { + 0f + } + val lastDetected = allTags.maxOfOrNull { it.detectedAt } + + PersonFaceStats( + personId = person.id, + personName = person.name, + faceModelId = faceModel.id, + trainingImageCount = faceModel.trainingImageCount, + taggedPhotoCount = imageIds.size, + averageConfidence = avgConfidence, + lastDetectedAt = lastDetected + ) + } + + /** + * Get face tags for an image. + */ + suspend fun getFaceTagsForImage(imageId: String): List { + return photoFaceTagDao.getTagsForImage(imageId) + } + + /** + * Get person from a face tag. + */ + suspend fun getPersonForFaceTag(tag: PhotoFaceTagEntity): PersonEntity? = withContext(Dispatchers.IO) { + val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@withContext null + personDao.getPersonById(faceModel.personId) + } + + /** + * Get face tags with person info for an image. + */ + suspend fun getFaceTagsWithPersons(imageId: String): List> = withContext(Dispatchers.IO) { + val tags = photoFaceTagDao.getTagsForImage(imageId) + tags.mapNotNull { tag -> + val person = getPersonForFaceTag(tag) + if (person != null) tag to person else null + } + } + + // ====================== + // VERIFICATION / QUALITY + // ====================== + + suspend fun verifyFaceTag(tagId: String) { + photoFaceTagDao.markTagAsVerified(tagId) + } + + suspend fun getUnverifiedTags(): List { + return photoFaceTagDao.getUnverifiedTags() + } + + suspend fun getLowConfidenceTags(threshold: Float = 0.7f): List { + return photoFaceTagDao.getLowConfidenceTags(threshold) + } + + // ====================== + // MANAGEMENT + // ====================== + + suspend fun deleteFaceModel(faceModelId: String) = withContext(Dispatchers.IO) { + photoFaceTagDao.deleteTagsForFaceModel(faceModelId) + faceModelDao.deleteFaceModelById(faceModelId) + } + + suspend fun deleteTagsForImage(imageId: String) { + photoFaceTagDao.deleteTagsForImage(imageId) + } + + fun cleanup() { + faceNetModel.close() + } +} + +data class DetectedFace( + val croppedBitmap: Bitmap, + val boundingBox: android.graphics.Rect +) + +data class PersonFaceStats( + val personId: String, + val personName: String, + val faceModelId: String, + val trainingImageCount: Int, + val taggedPhotoCount: Int, + val averageConfidence: Float, + val lastDetectedAt: Long? +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt b/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt index aa628ae..90acb2e 100644 --- a/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt +++ b/app/src/main/java/com/placeholder/sherpai2/di/DatabaseModule.kt @@ -3,10 +3,7 @@ package com.placeholder.sherpai2.di import android.content.Context import androidx.room.Room import com.placeholder.sherpai2.data.local.AppDatabase -import com.placeholder.sherpai2.data.local.dao.ImageAggregateDao -import com.placeholder.sherpai2.data.local.dao.ImageEventDao -import com.placeholder.sherpai2.data.local.dao.ImageTagDao -import com.placeholder.sherpai2.data.local.dao.TagDao +import com.placeholder.sherpai2.data.local.dao.* import dagger.Module import dagger.Provides import dagger.hilt.InstallIn @@ -14,6 +11,14 @@ import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.components.SingletonComponent import javax.inject.Singleton +/** + * DatabaseModule - Provides database and DAOs + * + * FRESH START VERSION: + * - No migration needed + * - Uses fallbackToDestructiveMigration (deletes old database) + * - Perfect for development + */ @Module @InstallIn(SingletonComponent::class) object DatabaseModule { @@ -27,35 +32,77 @@ object DatabaseModule { context, AppDatabase::class.java, "sherpai.db" - ).build() + ) + .fallbackToDestructiveMigration() // ← Deletes old database, creates fresh + .build() } - // --- Add these DAO providers --- + // ===== YOUR EXISTING DAOs ===== + + @Provides + fun provideImageDao(database: AppDatabase): ImageDao { + return database.imageDao() + } @Provides fun provideTagDao(database: AppDatabase): TagDao { return database.tagDao() } + @Provides + fun provideEventDao(database: AppDatabase): EventDao { + return database.eventDao() + } + @Provides fun provideImageTagDao(database: AppDatabase): ImageTagDao { return database.imageTagDao() } - // Add providers for your other DAOs now to avoid future errors @Provides - fun provideImageDao(database: AppDatabase) = database.imageDao() + fun provideImagePersonDao(database: AppDatabase): ImagePersonDao { + return database.imagePersonDao() + } @Provides - fun providePersonDao(database: AppDatabase) = database.personDao() + fun provideImageEventDao(database: AppDatabase): ImageEventDao { + return database.imageEventDao() + } @Provides - fun provideEventDao(database: AppDatabase) = database.eventDao() + fun provideImageAggregateDao(database: AppDatabase): ImageAggregateDao { + return database.imageAggregateDao() + } + + // ===== NEW FACE RECOGNITION DAOs ===== @Provides - fun provideImageEventDao(database: AppDatabase): ImageEventDao = database.imageEventDao() + fun providePersonDao(database: AppDatabase): PersonDao { + return database.personDao() + } @Provides - fun provideImageAggregateDao(database: AppDatabase): ImageAggregateDao = database.imageAggregateDao() + fun provideFaceModelDao(database: AppDatabase): FaceModelDao { + return database.faceModelDao() + } + + @Provides + fun providePhotoFaceTagDao(database: AppDatabase): PhotoFaceTagDao { + return database.photoFaceTagDao() + } } +/** + * NOTES: + * + * fallbackToDestructiveMigration(): + * - Deletes database if schema changes + * - Creates fresh database with new schema + * - Perfect for development + * - ⚠️ Users lose data on updates + * + * For production later: + * - Remove fallbackToDestructiveMigration() + * - Add .addMigrations(MIGRATION_1_2, MIGRATION_2_3, ...) + * - This preserves user data + */ \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/di/Mlmodule.kt b/app/src/main/java/com/placeholder/sherpai2/di/Mlmodule.kt new file mode 100644 index 0000000..db3932f --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/di/Mlmodule.kt @@ -0,0 +1,34 @@ +package com.placeholder.sherpai2.di + +import android.content.Context +import com.placeholder.sherpai2.ml.FaceNetModel +import dagger.Module +import dagger.Provides +import dagger.hilt.InstallIn +import dagger.hilt.android.qualifiers.ApplicationContext +import dagger.hilt.components.SingletonComponent +import javax.inject.Singleton + +/** + * MLModule - Provides ML-related dependencies + * + * This module provides FaceNetModel for dependency injection + */ +@Module +@InstallIn(SingletonComponent::class) +object MLModule { + + /** + * Provide FaceNetModel singleton + * + * FaceNetModel loads the MobileFaceNet TFLite model and manages + * face embedding generation for recognition. + */ + @Provides + @Singleton + fun provideFaceNetModel( + @ApplicationContext context: Context + ): FaceNetModel { + return FaceNetModel(context) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt b/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt index 1cb8bfb..2e97da7 100644 --- a/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt +++ b/app/src/main/java/com/placeholder/sherpai2/di/RepositoryModule.kt @@ -1,20 +1,35 @@ package com.placeholder.sherpai2.di - +import android.content.Context +import com.placeholder.sherpai2.data.local.dao.FaceModelDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.dao.PersonDao +import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao +import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl import com.placeholder.sherpai2.domain.repository.ImageRepository import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl import com.placeholder.sherpai2.domain.repository.TaggingRepository import dagger.Binds import dagger.Module +import dagger.Provides import dagger.hilt.InstallIn +import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.components.SingletonComponent import javax.inject.Singleton +/** + * RepositoryModule - Provides repository implementations + * + * UPDATED TO INCLUDE: + * - FaceRecognitionRepository for face recognition operations + */ @Module @InstallIn(SingletonComponent::class) abstract class RepositoryModule { + // ===== EXISTING REPOSITORY BINDINGS ===== + @Binds @Singleton abstract fun bindImageRepository( @@ -26,4 +41,50 @@ abstract class RepositoryModule { abstract fun bindTaggingRepository( impl: TaggingRepositoryImpl ): TaggingRepository -} + + // ===== COMPANION OBJECT FOR PROVIDES ===== + + companion object { + + /** + * Provide FaceRecognitionRepository + * + * Uses @Provides instead of @Binds because it needs Context parameter + * and multiple DAO dependencies + * + * INJECTED DEPENDENCIES: + * - Context: For FaceNetModel initialization + * - PersonDao: Access existing persons + * - ImageDao: Access existing images + * - FaceModelDao: Manage face models + * - PhotoFaceTagDao: Manage photo tags + * + * USAGE IN VIEWMODEL: + * ``` + * @HiltViewModel + * class MyViewModel @Inject constructor( + * private val faceRecognitionRepository: FaceRecognitionRepository + * ) : ViewModel() { + * // Use repository methods + * } + * ``` + */ + @Provides + @Singleton + fun provideFaceRecognitionRepository( + @ApplicationContext context: Context, + personDao: PersonDao, + imageDao: ImageDao, + faceModelDao: FaceModelDao, + photoFaceTagDao: PhotoFaceTagDao + ): FaceRecognitionRepository { + return FaceRecognitionRepository( + context = context, + personDao = personDao, + imageDao = imageDao, + faceModelDao = faceModelDao, + photoFaceTagDao = photoFaceTagDao + ) + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt new file mode 100644 index 0000000..22daad1 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt @@ -0,0 +1,204 @@ +package com.placeholder.sherpai2.ml + +import android.content.Context +import android.graphics.Bitmap +import org.tensorflow.lite.Interpreter +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.MappedByteBuffer +import java.nio.channels.FileChannel +import kotlin.math.sqrt + +/** + * FaceNetModel - MobileFaceNet wrapper for face recognition + * + * CLEAN IMPLEMENTATION: + * - All IDs are Strings (matching your schema) + * - Generates 192-dimensional embeddings + * - Cosine similarity for matching + */ +class FaceNetModel(private val context: Context) { + + companion object { + private const val MODEL_FILE = "mobilefacenet.tflite" + private const val INPUT_SIZE = 112 + private const val EMBEDDING_SIZE = 192 + + const val SIMILARITY_THRESHOLD_HIGH = 0.7f + const val SIMILARITY_THRESHOLD_MEDIUM = 0.6f + const val SIMILARITY_THRESHOLD_LOW = 0.5f + } + + private var interpreter: Interpreter? = null + + init { + try { + val model = loadModelFile() + interpreter = Interpreter(model) + } catch (e: Exception) { + throw RuntimeException("Failed to load FaceNet model", e) + } + } + + /** + * Load TFLite model from assets + */ + private fun loadModelFile(): MappedByteBuffer { + val fileDescriptor = context.assets.openFd(MODEL_FILE) + val inputStream = FileInputStream(fileDescriptor.fileDescriptor) + val fileChannel = inputStream.channel + val startOffset = fileDescriptor.startOffset + val declaredLength = fileDescriptor.declaredLength + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) + } + + /** + * Generate embedding for a single face + * + * @param faceBitmap Cropped face image (will be resized to 112x112) + * @return 192-dimensional embedding + */ + fun generateEmbedding(faceBitmap: Bitmap): FloatArray { + val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true) + val inputBuffer = preprocessImage(resized) + val output = Array(1) { FloatArray(EMBEDDING_SIZE) } + + interpreter?.run(inputBuffer, output) + + return normalizeEmbedding(output[0]) + } + + /** + * Generate embeddings for multiple faces (batch processing) + */ + fun generateEmbeddingsBatch( + faceBitmaps: List, + onProgress: (Int, Int) -> Unit = { _, _ -> } + ): List { + return faceBitmaps.mapIndexed { index, bitmap -> + onProgress(index + 1, faceBitmaps.size) + generateEmbedding(bitmap) + } + } + + /** + * Create person model by averaging multiple embeddings + */ + fun createPersonModel(embeddings: List): FloatArray { + require(embeddings.isNotEmpty()) { "Need at least one embedding" } + + val averaged = FloatArray(EMBEDDING_SIZE) { 0f } + + embeddings.forEach { embedding -> + for (i in embedding.indices) { + averaged[i] += embedding[i] + } + } + + val count = embeddings.size.toFloat() + for (i in averaged.indices) { + averaged[i] /= count + } + + return normalizeEmbedding(averaged) + } + + /** + * Calculate cosine similarity between two embeddings + * Returns value between -1.0 and 1.0 (higher = more similar) + */ + fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float { + require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) { + "Invalid embedding size" + } + + var dotProduct = 0f + var norm1 = 0f + var norm2 = 0f + + for (i in embedding1.indices) { + dotProduct += embedding1[i] * embedding2[i] + norm1 += embedding1[i] * embedding1[i] + norm2 += embedding2[i] * embedding2[i] + } + + return dotProduct / (sqrt(norm1) * sqrt(norm2)) + } + + /** + * Find best matching face model from a list + * + * @param faceEmbedding Embedding to match + * @param modelEmbeddings List of (modelId: String, embedding: FloatArray) + * @param threshold Minimum similarity threshold + * @return Pair of (modelId: String, confidence: Float) or null + */ + fun findBestMatch( + faceEmbedding: FloatArray, + modelEmbeddings: List>, + threshold: Float = SIMILARITY_THRESHOLD_HIGH + ): Pair? { + var bestMatch: Pair? = null + var highestSimilarity = threshold + + for ((modelId, modelEmbedding) in modelEmbeddings) { + val similarity = calculateSimilarity(faceEmbedding, modelEmbedding) + + if (similarity > highestSimilarity) { + highestSimilarity = similarity + bestMatch = Pair(modelId, similarity) + } + } + + return bestMatch + } + + /** + * Preprocess image for model input + */ + private fun preprocessImage(bitmap: Bitmap): ByteBuffer { + val buffer = ByteBuffer.allocateDirect(4 * INPUT_SIZE * INPUT_SIZE * 3) + buffer.order(ByteOrder.nativeOrder()) + + val pixels = IntArray(INPUT_SIZE * INPUT_SIZE) + bitmap.getPixels(pixels, 0, INPUT_SIZE, 0, 0, INPUT_SIZE, INPUT_SIZE) + + for (pixel in pixels) { + val r = ((pixel shr 16) and 0xFF) / 255.0f + val g = ((pixel shr 8) and 0xFF) / 255.0f + val b = (pixel and 0xFF) / 255.0f + + buffer.putFloat((r - 0.5f) / 0.5f) + buffer.putFloat((g - 0.5f) / 0.5f) + buffer.putFloat((b - 0.5f) / 0.5f) + } + + return buffer + } + + /** + * Normalize embedding to unit length + */ + private fun normalizeEmbedding(embedding: FloatArray): FloatArray { + var norm = 0f + for (value in embedding) { + norm += value * value + } + norm = sqrt(norm) + + return if (norm > 0) { + FloatArray(embedding.size) { i -> embedding[i] / norm } + } else { + embedding + } + } + + /** + * Clean up resources + */ + fun close() { + interpreter?.close() + interpreter = null + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/devscreens/DummyScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/devscreens/DummyScreen.kt index c453471..9693bc2 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/devscreens/DummyScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/devscreens/DummyScreen.kt @@ -1,17 +1,162 @@ package com.placeholder.sherpai2.ui.devscreens +import androidx.compose.foundation.background import androidx.compose.foundation.layout.* -import androidx.compose.material3.Text +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* import androidx.compose.runtime.Composable import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Brush +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +/** + * Beautiful placeholder screen for features under development + * + * Shows: + * - Feature name + * - Description + * - "Coming Soon" indicator + * - Consistent styling with rest of app + */ @Composable -fun DummyScreen(label: String) { +fun DummyScreen( + title: String, + subtitle: String = "This feature is under development" +) { Box( - modifier = Modifier.fillMaxSize(), + modifier = Modifier + .fillMaxSize() + .background( + Brush.verticalGradient( + colors = listOf( + MaterialTheme.colorScheme.surface, + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) + ) + ), contentAlignment = Alignment.Center ) { - Text(label) + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(24.dp), + modifier = Modifier.padding(48.dp) + ) { + + // Icon badge + Surface( + modifier = Modifier.size(96.dp), + shape = RoundedCornerShape(24.dp), + color = MaterialTheme.colorScheme.primaryContainer, + shadowElevation = 8.dp + ) { + Box(contentAlignment = Alignment.Center) { + Icon( + Icons.Default.Construction, + contentDescription = null, + modifier = Modifier.size(48.dp), + tint = MaterialTheme.colorScheme.primary + ) + } + } + + Spacer(modifier = Modifier.height(8.dp)) + + // Title + Text( + text = title, + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + textAlign = TextAlign.Center + ) + + // Subtitle + Text( + text = subtitle, + style = MaterialTheme.typography.bodyLarge, + color = MaterialTheme.colorScheme.onSurfaceVariant, + textAlign = TextAlign.Center, + modifier = Modifier.padding(horizontal = 24.dp) + ) + + Spacer(modifier = Modifier.height(8.dp)) + + // Coming soon badge + Surface( + shape = RoundedCornerShape(16.dp), + color = MaterialTheme.colorScheme.tertiaryContainer, + shadowElevation = 2.dp + ) { + Row( + modifier = Modifier.padding(horizontal = 20.dp, vertical = 12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.Schedule, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.onTertiaryContainer + ) + Text( + text = "Coming Soon", + style = MaterialTheme.typography.labelLarge, + fontWeight = FontWeight.SemiBold, + color = MaterialTheme.colorScheme.onTertiaryContainer + ) + } + } + + Spacer(modifier = Modifier.height(24.dp)) + + // Feature preview card + Card( + modifier = Modifier.fillMaxWidth(0.8f), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + ), + shape = RoundedCornerShape(16.dp) + ) { + Column( + modifier = Modifier.padding(20.dp), + verticalArrangement = Arrangement.spacedBy(12.dp) + ) { + Text( + text = "What's planned:", + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold + ) + + FeatureItem("Full implementation") + FeatureItem("Beautiful UI design") + FeatureItem("Smooth animations") + FeatureItem("Production-ready code") + } + } + } } } + +@Composable +private fun FeatureItem(text: String) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.CheckCircle, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.primary + ) + Text( + text = text, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt new file mode 100644 index 0000000..0668180 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt @@ -0,0 +1,614 @@ +package com.placeholder.sherpai2.ui.modelinventory + +import androidx.compose.foundation.background +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.vector.ImageVector +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.hilt.navigation.compose.hiltViewModel +import java.text.SimpleDateFormat +import java.util.* + +/** + * PersonInventoryScreen - Manage trained face models + * + * Features: + * - List all trained persons + * - View stats + * - DELETE models + * - SCAN LIBRARY to find person in all photos (NEW!) + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun PersonInventoryScreen( + modifier: Modifier = Modifier, + viewModel: PersonInventoryViewModel = hiltViewModel(), + onViewPersonPhotos: (String) -> Unit = {} +) { + val uiState by viewModel.uiState.collectAsState() + val scanningState by viewModel.scanningState.collectAsState() + + var personToDelete by remember { mutableStateOf(null) } + var personToScan by remember { mutableStateOf(null) } + + Scaffold( + topBar = { + TopAppBar( + title = { Text("Trained People") }, + colors = TopAppBarDefaults.topAppBarColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ), + actions = { + IconButton(onClick = { viewModel.loadPersons() }) { + Icon(Icons.Default.Refresh, contentDescription = "Refresh") + } + } + ) + } + ) { paddingValues -> + Box( + modifier = modifier + .fillMaxSize() + .padding(paddingValues) + ) { + when (val state = uiState) { + is PersonInventoryViewModel.InventoryUiState.Loading -> { + LoadingView() + } + + is PersonInventoryViewModel.InventoryUiState.Success -> { + if (state.persons.isEmpty()) { + EmptyView() + } else { + PersonListView( + persons = state.persons, + onDeleteClick = { personToDelete = it }, + onScanClick = { personToScan = it }, + onViewPhotos = { onViewPersonPhotos(it.person.id) }, + scanningState = scanningState + ) + } + } + + is PersonInventoryViewModel.InventoryUiState.Error -> { + ErrorView( + message = state.message, + onRetry = { viewModel.loadPersons() } + ) + } + } + + // Scanning overlay + if (scanningState is PersonInventoryViewModel.ScanningState.Scanning) { + ScanningOverlay(scanningState as PersonInventoryViewModel.ScanningState.Scanning) + } + } + } + + // Delete confirmation dialog + personToDelete?.let { personWithStats -> + AlertDialog( + onDismissRequest = { personToDelete = null }, + title = { Text("Delete ${personWithStats.person.name}?") }, + text = { + Text( + "This will delete the face model and all ${personWithStats.stats.taggedPhotoCount} " + + "face tags. Your photos will NOT be deleted." + ) + }, + confirmButton = { + TextButton( + onClick = { + viewModel.deletePerson( + personWithStats.person.id, + personWithStats.stats.faceModelId + ) + personToDelete = null + }, + colors = ButtonDefaults.textButtonColors( + contentColor = MaterialTheme.colorScheme.error + ) + ) { + Text("Delete") + } + }, + dismissButton = { + TextButton(onClick = { personToDelete = null }) { + Text("Cancel") + } + } + ) + } + + // Scan library confirmation dialog + personToScan?.let { personWithStats -> + AlertDialog( + onDismissRequest = { personToScan = null }, + icon = { Icon(Icons.Default.Search, contentDescription = null) }, + title = { Text("Scan Library for ${personWithStats.person.name}?") }, + text = { + Column(verticalArrangement = Arrangement.spacedBy(12.dp)) { + Text( + "This will scan your entire photo library and automatically tag " + + "all photos containing ${personWithStats.person.name}." + ) + Text( + "Currently tagged: ${personWithStats.stats.taggedPhotoCount} photos", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + }, + confirmButton = { + Button( + onClick = { + viewModel.scanLibraryForPerson( + personWithStats.person.id, + personWithStats.stats.faceModelId + ) + personToScan = null + } + ) { + Icon(Icons.Default.Search, contentDescription = null) + Spacer(modifier = Modifier.width(8.dp)) + Text("Start Scan") + } + }, + dismissButton = { + TextButton(onClick = { personToScan = null }) { + Text("Cancel") + } + } + ) + } +} + +@Composable +private fun LoadingView() { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + CircularProgressIndicator() + Text( + text = "Loading trained models...", + style = MaterialTheme.typography.bodyMedium + ) + } + } +} + +@Composable +private fun EmptyView() { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp), + modifier = Modifier.padding(32.dp) + ) { + Icon( + Icons.Default.Face, + contentDescription = null, + modifier = Modifier.size(64.dp), + tint = MaterialTheme.colorScheme.primary.copy(alpha = 0.5f) + ) + Text( + text = "No trained people yet", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = "Train a person using 10+ photos to start recognizing faces", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } +} + +@Composable +private fun ErrorView( + message: String, + onRetry: () -> Unit +) { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp), + modifier = Modifier.padding(32.dp) + ) { + Icon( + Icons.Default.Warning, + contentDescription = null, + modifier = Modifier.size(64.dp), + tint = MaterialTheme.colorScheme.error + ) + Text( + text = "Error", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = message, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Button(onClick = onRetry) { + Icon(Icons.Default.Refresh, contentDescription = null) + Spacer(modifier = Modifier.width(8.dp)) + Text("Retry") + } + } + } +} + +@Composable +private fun PersonListView( + persons: List, + onDeleteClick: (PersonInventoryViewModel.PersonWithStats) -> Unit, + onScanClick: (PersonInventoryViewModel.PersonWithStats) -> Unit, + onViewPhotos: (PersonInventoryViewModel.PersonWithStats) -> Unit, + scanningState: PersonInventoryViewModel.ScanningState +) { + LazyColumn( + contentPadding = PaddingValues(16.dp), + verticalArrangement = Arrangement.spacedBy(12.dp) + ) { + // Summary card + item { + SummaryCard(totalPersons = persons.size) + Spacer(modifier = Modifier.height(8.dp)) + } + + // Person cards + items(persons) { personWithStats -> + PersonCard( + personWithStats = personWithStats, + onDeleteClick = { onDeleteClick(personWithStats) }, + onScanClick = { onScanClick(personWithStats) }, + onViewPhotos = { onViewPhotos(personWithStats) }, + isScanning = scanningState is PersonInventoryViewModel.ScanningState.Scanning && + scanningState.personId == personWithStats.person.id + ) + } + } +} + +@Composable +private fun SummaryCard(totalPersons: Int) { + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.spacedBy(16.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.Face, + contentDescription = null, + modifier = Modifier.size(48.dp), + tint = MaterialTheme.colorScheme.primary + ) + Column { + Text( + text = "$totalPersons trained ${if (totalPersons == 1) "person" else "people"}", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + Text( + text = "Face recognition models ready", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f) + ) + } + } + } +} + +@Composable +private fun PersonCard( + personWithStats: PersonInventoryViewModel.PersonWithStats, + onDeleteClick: () -> Unit, + onScanClick: () -> Unit, + onViewPhotos: () -> Unit, + isScanning: Boolean +) { + val stats = personWithStats.stats + + Card( + modifier = Modifier.fillMaxWidth(), + elevation = CardDefaults.cardElevation(defaultElevation = 2.dp) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp) + ) { + // Header: Name and actions + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + horizontalArrangement = Arrangement.spacedBy(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Box( + modifier = Modifier + .size(48.dp) + .clip(CircleShape) + .background(MaterialTheme.colorScheme.primary), + contentAlignment = Alignment.Center + ) { + Text( + text = personWithStats.person.name.take(1).uppercase(), + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onPrimary + ) + } + + Column { + Text( + text = personWithStats.person.name, + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Text( + text = "ID: ${personWithStats.person.id.take(8)}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + IconButton(onClick = onDeleteClick) { + Icon( + Icons.Default.Delete, + contentDescription = "Delete", + tint = MaterialTheme.colorScheme.error + ) + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + // Stats grid + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceEvenly + ) { + StatItem( + icon = Icons.Default.PhotoCamera, + label = "Training", + value = "${stats.trainingImageCount}" + ) + StatItem( + icon = Icons.Default.AccountBox, + label = "Tagged", + value = "${stats.taggedPhotoCount}" + ) + StatItem( + icon = Icons.Default.CheckCircle, + label = "Confidence", + value = "${(stats.averageConfidence * 100).toInt()}%", + valueColor = if (stats.averageConfidence >= 0.8f) { + MaterialTheme.colorScheme.primary + } else if (stats.averageConfidence >= 0.6f) { + MaterialTheme.colorScheme.tertiary + } else { + MaterialTheme.colorScheme.error + } + ) + } + + Spacer(modifier = Modifier.height(16.dp)) + + // Last detected + stats.lastDetectedAt?.let { timestamp -> + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.surfaceVariant, + shape = RoundedCornerShape(8.dp) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.DateRange, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "Last detected: ${formatDate(timestamp)}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + + Spacer(modifier = Modifier.height(12.dp)) + + // Action buttons row + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + // Scan Library button (PRIMARY ACTION) + Button( + onClick = onScanClick, + modifier = Modifier.weight(1f), + enabled = !isScanning, + colors = ButtonDefaults.buttonColors( + containerColor = MaterialTheme.colorScheme.primary + ) + ) { + if (isScanning) { + CircularProgressIndicator( + modifier = Modifier.size(16.dp), + color = MaterialTheme.colorScheme.onPrimary, + strokeWidth = 2.dp + ) + } else { + Icon( + Icons.Default.Search, + contentDescription = null, + modifier = Modifier.size(18.dp) + ) + } + Spacer(modifier = Modifier.width(8.dp)) + Text(if (isScanning) "Scanning..." else "Scan Library") + } + + // View photos button + if (stats.taggedPhotoCount > 0) { + OutlinedButton( + onClick = onViewPhotos, + modifier = Modifier.weight(1f) + ) { + Icon( + Icons.Default.Photo, + contentDescription = null, + modifier = Modifier.size(18.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + Text("View (${stats.taggedPhotoCount})") + } + } + } + } + } +} + +@Composable +private fun StatItem( + icon: ImageVector, + label: String, + value: String, + valueColor: Color = MaterialTheme.colorScheme.primary +) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + Icon( + icon, + contentDescription = null, + modifier = Modifier.size(24.dp), + tint = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = value, + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + color = valueColor + ) + Text( + text = label, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} + +/** + * Scanning overlay showing progress + */ +@Composable +private fun ScanningOverlay(state: PersonInventoryViewModel.ScanningState.Scanning) { + Box( + modifier = Modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surface.copy(alpha = 0.95f)), + contentAlignment = Alignment.Center + ) { + Card( + modifier = Modifier + .fillMaxWidth(0.85f) + .padding(24.dp), + elevation = CardDefaults.cardElevation(defaultElevation = 8.dp) + ) { + Column( + modifier = Modifier.padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + Icon( + Icons.Default.Search, + contentDescription = null, + modifier = Modifier.size(48.dp), + tint = MaterialTheme.colorScheme.primary + ) + + Text( + text = "Scanning Library", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + + Text( + text = "Finding ${state.personName} in your photos...", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + LinearProgressIndicator( + progress = { state.progress / state.total.toFloat() }, + modifier = Modifier.fillMaxWidth(), + ) + + Text( + text = "${state.progress} / ${state.total} photos scanned", + style = MaterialTheme.typography.bodySmall + ) + + Text( + text = "${state.facesFound} faces detected", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.primary + ) + } + } + } +} + +private fun formatDate(timestamp: Long): String { + val formatter = SimpleDateFormat("MMM d, yyyy h:mm a", Locale.getDefault()) + return formatter.format(Date(timestamp)) +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt new file mode 100644 index 0000000..1753852 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt @@ -0,0 +1,299 @@ +package com.placeholder.sherpai2.ui.modelinventory + +import android.app.Application +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.net.Uri +import androidx.lifecycle.AndroidViewModel +import androidx.lifecycle.viewModelScope +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import com.placeholder.sherpai2.data.local.entity.PersonEntity +import com.placeholder.sherpai2.data.repository.DetectedFace +import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository +import com.placeholder.sherpai2.data.repository.PersonFaceStats +import com.placeholder.sherpai2.domain.repository.ImageRepository +import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import kotlinx.coroutines.tasks.await +import javax.inject.Inject + +/** + * PersonInventoryViewModel - Manage trained face models + * + * Features: + * - List all trained persons with stats + * - Delete models + * - SCAN LIBRARY to find person in all photos + * - View sample photos + */ +@HiltViewModel +class PersonInventoryViewModel @Inject constructor( + application: Application, + private val faceRecognitionRepository: FaceRecognitionRepository, + private val imageRepository: ImageRepository +) : AndroidViewModel(application) { + + private val _uiState = MutableStateFlow(InventoryUiState.Loading) + val uiState: StateFlow = _uiState.asStateFlow() + + private val _scanningState = MutableStateFlow(ScanningState.Idle) + val scanningState: StateFlow = _scanningState.asStateFlow() + + // ML Kit face detector + private val faceDetector by lazy { + val options = FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE) + .setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE) + .setMinFaceSize(0.15f) + .build() + FaceDetection.getClient(options) + } + + data class PersonWithStats( + val person: PersonEntity, + val stats: PersonFaceStats + ) + + sealed class InventoryUiState { + object Loading : InventoryUiState() + data class Success(val persons: List) : InventoryUiState() + data class Error(val message: String) : InventoryUiState() + } + + sealed class ScanningState { + object Idle : ScanningState() + data class Scanning( + val personId: String, + val personName: String, + val progress: Int, + val total: Int, + val facesFound: Int + ) : ScanningState() + data class Complete( + val personName: String, + val facesFound: Int, + val imagesScanned: Int + ) : ScanningState() + } + + init { + loadPersons() + } + + /** + * Load all trained persons with their stats + */ + fun loadPersons() { + viewModelScope.launch { + try { + _uiState.value = InventoryUiState.Loading + + val persons = faceRecognitionRepository.getPersonsWithFaceModels() + + val personsWithStats = persons.mapNotNull { person -> + val stats = faceRecognitionRepository.getPersonFaceStats(person.id) + if (stats != null) { + PersonWithStats(person, stats) + } else { + null + } + }.sortedByDescending { it.stats.taggedPhotoCount } + + _uiState.value = InventoryUiState.Success(personsWithStats) + + } catch (e: Exception) { + _uiState.value = InventoryUiState.Error( + e.message ?: "Failed to load persons" + ) + } + } + } + + /** + * Delete a face model + */ + fun deletePerson(personId: String, faceModelId: String) { + viewModelScope.launch { + try { + faceRecognitionRepository.deleteFaceModel(faceModelId) + loadPersons() // Refresh list + } catch (e: Exception) { + _uiState.value = InventoryUiState.Error( + "Failed to delete: ${e.message}" + ) + } + } + } + + /** + * Scan entire photo library for a specific person + * + * Process: + * 1. Get all images from library + * 2. For each image: + * - Detect faces using ML Kit + * - Generate embeddings for detected faces + * - Compare to person's face model + * - Create PhotoFaceTagEntity if match found + * 3. Update progress throughout + */ + fun scanLibraryForPerson(personId: String, faceModelId: String) { + viewModelScope.launch { + try { + // Get person name for UI + val currentState = _uiState.value + val person = if (currentState is InventoryUiState.Success) { + currentState.persons.find { it.person.id == personId }?.person + } else null + + val personName = person?.name ?: "Unknown" + + // Get all images from library + val allImages = imageRepository.getAllImages().first() + val totalImages = allImages.size + + _scanningState.value = ScanningState.Scanning( + personId = personId, + personName = personName, + progress = 0, + total = totalImages, + facesFound = 0 + ) + + var facesFound = 0 + + // Scan each image + allImages.forEachIndexed { index, imageWithEverything -> + val image = imageWithEverything.image + + // Detect faces in this image + val detectedFaces = detectFacesInImage(image.imageUri) + + if (detectedFaces.isNotEmpty()) { + // Scan this image for the person + val tags = faceRecognitionRepository.scanImage( + imageId = image.imageId, + detectedFaces = detectedFaces, + threshold = 0.6f // Slightly lower threshold for library scanning + ) + + // Count how many faces matched this person + val matchingTags = tags.filter { tag -> + // Check if this tag belongs to our target person's face model + tag.faceModelId == faceModelId + } + + facesFound += matchingTags.size + } + + // Update progress + _scanningState.value = ScanningState.Scanning( + personId = personId, + personName = personName, + progress = index + 1, + total = totalImages, + facesFound = facesFound + ) + } + + // Scan complete + _scanningState.value = ScanningState.Complete( + personName = personName, + facesFound = facesFound, + imagesScanned = totalImages + ) + + // Refresh the list to show updated counts + loadPersons() + + // Reset scanning state after 3 seconds + delay(3000) + _scanningState.value = ScanningState.Idle + + } catch (e: Exception) { + _scanningState.value = ScanningState.Idle + _uiState.value = InventoryUiState.Error( + "Scan failed: ${e.message}" + ) + } + } + } + + /** + * Detect faces in an image using ML Kit + * + * @param imageUri URI of the image to scan + * @return List of detected faces with cropped bitmaps + */ + private suspend fun detectFacesInImage(imageUri: String): List = withContext(Dispatchers.Default) { + try { + // Load bitmap from URI + val uri = Uri.parse(imageUri) + val inputStream = getApplication().contentResolver.openInputStream(uri) + val bitmap = BitmapFactory.decodeStream(inputStream) + inputStream?.close() + + if (bitmap == null) return@withContext emptyList() + + // Create ML Kit input image + val image = InputImage.fromBitmap(bitmap, 0) + + // Detect faces (await the Task) + val faces = faceDetector.process(image).await() + + // Convert to DetectedFace objects + faces.mapNotNull { face -> + val boundingBox = face.boundingBox + + // Crop face from bitmap with bounds checking + val croppedFace = try { + val left = boundingBox.left.coerceAtLeast(0) + val top = boundingBox.top.coerceAtLeast(0) + val width = boundingBox.width().coerceAtMost(bitmap.width - left) + val height = boundingBox.height().coerceAtMost(bitmap.height - top) + + if (width > 0 && height > 0) { + Bitmap.createBitmap(bitmap, left, top, width, height) + } else { + null + } + } catch (e: Exception) { + null + } + + if (croppedFace != null) { + DetectedFace( + croppedBitmap = croppedFace, + boundingBox = boundingBox + ) + } else { + null + } + } + + } catch (e: Exception) { + emptyList() + } + } + + /** + * Get sample images for a person + */ + suspend fun getPersonImages(personId: String) = + faceRecognitionRepository.getImagesForPerson(personId) + + override fun onCleared() { + super.onCleared() + faceDetector.close() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt index 57305ce..f1ca03f 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppDestinations.kt @@ -1,46 +1,157 @@ package com.placeholder.sherpai2.ui.navigation import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.filled.Label import androidx.compose.material.icons.filled.* import androidx.compose.ui.graphics.vector.ImageVector /** - * Drawer-only metadata. + * AppDestinations - Navigation metadata for drawer UI * - * These objects: - * - Drive the drawer UI - * - Provide labels and icons - * - Map cleanly to navigation routes + * Clean, organized structure: + * - Routes for navigation + * - Icons for visual identity + * - Labels for display + * - Descriptions for clarity + * - Grouped by function */ sealed class AppDestinations( val route: String, val icon: ImageVector, - val label: String + val label: String, + val description: String = "" ) { - object Tour : AppDestinations(AppRoutes.TOUR, Icons.Default.PhotoLibrary, "Tour") - object Search : AppDestinations(AppRoutes.SEARCH, Icons.Default.Search, "Search") - object Models : AppDestinations(AppRoutes.MODELS, Icons.Default.Layers, "Models") - object Inventory : AppDestinations(AppRoutes.INVENTORY, Icons.Default.Inventory2, "Inv") - object Train : AppDestinations(AppRoutes.TRAIN, Icons.Default.TrackChanges, "Train") - object Tags : AppDestinations(AppRoutes.TAGS, Icons.Default.LocalOffer, "Tags") - object ImageDetails : AppDestinations(AppRoutes.IMAGE_DETAIL, Icons.Default.LocalOffer, "IMAGE_DETAIL") + // ================== + // PHOTO BROWSING + // ================== - object Upload : AppDestinations(AppRoutes.UPLOAD, Icons.Default.CloudUpload, "Upload") - object Settings : AppDestinations(AppRoutes.SETTINGS, Icons.Default.Settings, "Settings") + data object Search : AppDestinations( + route = AppRoutes.SEARCH, + icon = Icons.Default.Search, + label = "Search", + description = "Find photos by tag or person" + ) + + data object Tour : AppDestinations( + route = AppRoutes.TOUR, + icon = Icons.Default.Place, + label = "Tour", + description = "Browse by location & time" + ) + + // ImageDetail is not in drawer (internal navigation only) + + // ================== + // FACE RECOGNITION + // ================== + + data object Inventory : AppDestinations( + route = AppRoutes.INVENTORY, + icon = Icons.Default.Face, + label = "People", + description = "Trained face models" + ) + + data object Train : AppDestinations( + route = AppRoutes.TRAIN, + icon = Icons.Default.ModelTraining, + label = "Train", + description = "Train new person" + ) + + data object Models : AppDestinations( + route = AppRoutes.MODELS, + icon = Icons.Default.SmartToy, + label = "Models", + description = "AI model management" + ) + + // ================== + // ORGANIZATION + // ================== + + data object Tags : AppDestinations( + route = AppRoutes.TAGS, + icon = Icons.AutoMirrored.Filled.Label, + label = "Tags", + description = "Manage photo tags" + ) + + data object Upload : AppDestinations( + route = AppRoutes.UPLOAD, + icon = Icons.Default.UploadFile, + label = "Upload", + description = "Add new photos" + ) + + // ================== + // SETTINGS + // ================== + + data object Settings : AppDestinations( + route = AppRoutes.SETTINGS, + icon = Icons.Default.Settings, + label = "Settings", + description = "App preferences" + ) } -val mainDrawerItems = listOf( - AppDestinations.Tour, +/** + * Organized destination groups for beautiful drawer sections + */ + +// Photo browsing section +val photoDestinations = listOf( AppDestinations.Search, - AppDestinations.Models, - AppDestinations.Inventory, - AppDestinations.Train, - AppDestinations.Tags, - AppDestinations.ImageDetails + AppDestinations.Tour ) -val utilityDrawerItems = listOf( - AppDestinations.Upload, - AppDestinations.Settings +// Face recognition section +val faceRecognitionDestinations = listOf( + AppDestinations.Inventory, + AppDestinations.Train, + AppDestinations.Models ) + +// Organization section +val organizationDestinations = listOf( + AppDestinations.Tags, + AppDestinations.Upload +) + +// Settings (separate, pinned to bottom) +val settingsDestination = AppDestinations.Settings + +/** + * All drawer items (excludes Settings which is handled separately) + */ +val allMainDrawerDestinations = photoDestinations + faceRecognitionDestinations + organizationDestinations + +/** + * Helper function to get destination by route + * Useful for highlighting current route in drawer + */ +fun getDestinationByRoute(route: String?): AppDestinations? { + return when (route) { + AppRoutes.SEARCH -> AppDestinations.Search + AppRoutes.TOUR -> AppDestinations.Tour + AppRoutes.INVENTORY -> AppDestinations.Inventory + AppRoutes.TRAIN -> AppDestinations.Train + AppRoutes.MODELS -> AppDestinations.Models + AppRoutes.TAGS -> AppDestinations.Tags + AppRoutes.UPLOAD -> AppDestinations.Upload + AppRoutes.SETTINGS -> AppDestinations.Settings + else -> null + } +} + +/** + * Legacy support (for backwards compatibility) + * These match your old structure + */ +@Deprecated("Use organized groups instead", ReplaceWith("allMainDrawerDestinations")) +val mainDrawerItems = allMainDrawerDestinations + +@Deprecated("Use settingsDestination instead", ReplaceWith("listOf(settingsDestination)")) +val utilityDrawerItems = listOf(settingsDestination) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt index 3bac704..3b562fa 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/navigation/AppNavHost.kt @@ -2,9 +2,11 @@ package com.placeholder.sherpai2.ui.navigation import android.net.Uri import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue import androidx.compose.ui.Modifier import androidx.hilt.navigation.compose.hiltViewModel -import androidx.lifecycle.ViewModel import androidx.navigation.NavHostController import androidx.navigation.NavType import androidx.navigation.compose.NavHost @@ -12,25 +14,34 @@ import androidx.navigation.compose.composable import androidx.navigation.navArgument import com.placeholder.sherpai2.ui.devscreens.DummyScreen import com.placeholder.sherpai2.ui.imagedetail.ImageDetailScreen +import com.placeholder.sherpai2.ui.modelinventory.PersonInventoryScreen import com.placeholder.sherpai2.ui.search.SearchScreen import com.placeholder.sherpai2.ui.search.SearchViewModel -import java.net.URLDecoder -import java.net.URLEncoder -import com.placeholder.sherpai2.ui.tour.TourViewModel import com.placeholder.sherpai2.ui.tour.TourScreen +import com.placeholder.sherpai2.ui.tour.TourViewModel import com.placeholder.sherpai2.ui.trainingprep.ImageSelectorScreen -import com.placeholder.sherpai2.ui.trainingprep.TrainingScreen -import com.placeholder.sherpai2.ui.navigation.AppRoutes -import com.placeholder.sherpai2.ui.navigation.AppRoutes.ScanResultsScreen +import com.placeholder.sherpai2.ui.trainingprep.ScanResultsScreen import com.placeholder.sherpai2.ui.trainingprep.ScanningState import com.placeholder.sherpai2.ui.trainingprep.TrainViewModel -import androidx.compose.runtime.LaunchedEffect -import androidx.compose.runtime.collectAsState -import androidx.compose.runtime.getValue -import com.placeholder.sherpai2.ui.trainingprep.ScanResultsScreen - - +import com.placeholder.sherpai2.ui.trainingprep.TrainingScreen +import java.net.URLDecoder +import java.net.URLEncoder +/** + * AppNavHost - Main navigation graph + * + * Complete flow: + * - Photo browsing (Search, Tour, Detail) + * - Face recognition (Inventory, Train) + * - Organization (Tags, Upload) + * - Settings + * + * Features: + * - URL encoding for safe navigation + * - Proper back stack management + * - State preservation + * - Beautiful placeholders + */ @Composable fun AppNavHost( navController: NavHostController, @@ -42,20 +53,29 @@ fun AppNavHost( modifier = modifier ) { - /** SEARCH SCREEN **/ + // ========================================== + // PHOTO BROWSING + // ========================================== + + /** + * SEARCH SCREEN + * Main photo browser with face tag search + */ composable(AppRoutes.SEARCH) { val searchViewModel: SearchViewModel = hiltViewModel() SearchScreen( searchViewModel = searchViewModel, onImageClick = { imageUri -> - // Encode the URI to safely pass as argument val encodedUri = URLEncoder.encode(imageUri, "UTF-8") navController.navigate("${AppRoutes.IMAGE_DETAIL}/$encodedUri") } ) } - /** IMAGE DETAIL SCREEN **/ + /** + * IMAGE DETAIL SCREEN + * Single photo view with metadata + */ composable( route = "${AppRoutes.IMAGE_DETAIL}/{imageUri}", arguments = listOf( @@ -64,8 +84,6 @@ fun AppNavHost( } ) ) { backStackEntry -> - - // Decode URI to restore original value val imageUri = backStackEntry.arguments?.getString("imageUri") ?.let { URLDecoder.decode(it, "UTF-8") } ?: error("imageUri missing from navigation") @@ -76,70 +94,160 @@ fun AppNavHost( ) } + /** + * TOUR SCREEN + * Browse photos by location and time + */ composable(AppRoutes.TOUR) { val tourViewModel: TourViewModel = hiltViewModel() TourScreen( tourViewModel = tourViewModel, onImageClick = { imageUri -> - navController.navigate("${AppRoutes.IMAGE_DETAIL}/$imageUri") + val encodedUri = URLEncoder.encode(imageUri, "UTF-8") + navController.navigate("${AppRoutes.IMAGE_DETAIL}/$encodedUri") } ) } - /** TRAINING FLOW **/ + // ========================================== + // FACE RECOGNITION SYSTEM + // ========================================== + + /** + * PERSON INVENTORY SCREEN + * View all trained face models + * + * Features: + * - List all trained people + * - Show stats (training count, tagged photos, confidence) + * - Delete models + * - View photos containing each person + */ + composable(AppRoutes.INVENTORY) { + PersonInventoryScreen( + onViewPersonPhotos = { personId -> + // Navigate back to search + // TODO: In future, add person filter to search screen + navController.navigate(AppRoutes.SEARCH) + } + ) + } + + /** + * TRAINING FLOW + * Train new face recognition model + * + * Flow: + * 1. TrainingScreen (select images button) + * 2. ImageSelectorScreen (pick 10+ photos) + * 3. ScanResultsScreen (validation + name input) + * 4. Training completes → navigate to Inventory + */ composable(AppRoutes.TRAIN) { entry -> val trainViewModel: TrainViewModel = hiltViewModel() val uiState by trainViewModel.uiState.collectAsState() - // Observe the result from the ImageSelector + // Get images selected from ImageSelector val selectedUris = entry.savedStateHandle.get>("selected_image_uris") - // If we have new URIs and we are currently Idle, start scanning + // Start scanning when new images are selected LaunchedEffect(selectedUris) { if (selectedUris != null && uiState is ScanningState.Idle) { trainViewModel.scanAndTagFaces(selectedUris) - // Clear the handle so it doesn't re-trigger on configuration change entry.savedStateHandle.remove>("selected_image_uris") } } - if (uiState is ScanningState.Idle) { - // Initial state: Show start button or prompt - TrainingScreen( - onSelectImages = { navController.navigate(AppRoutes.IMAGE_SELECTOR) } - ) - } else { - // Processing or Success state: Show the results screen - ScanResultsScreen( - state = uiState, - onFinish = { - navController.navigate(AppRoutes.SEARCH) { - popUpTo(AppRoutes.TRAIN) { inclusive = true } + when (uiState) { + is ScanningState.Idle -> { + // Show start screen with "Select Images" button + TrainingScreen( + onSelectImages = { + navController.navigate(AppRoutes.IMAGE_SELECTOR) } - } - ) + ) + } + else -> { + // Show validation results and training UI + ScanResultsScreen( + state = uiState, + onFinish = { + // After training, go to inventory to see new person + navController.navigate(AppRoutes.INVENTORY) { + popUpTo(AppRoutes.TRAIN) { inclusive = true } + } + } + ) + } } } + /** + * IMAGE SELECTOR SCREEN + * Pick images for training (internal screen) + */ composable(AppRoutes.IMAGE_SELECTOR) { ImageSelectorScreen( onImagesSelected = { uris -> + // Pass selected URIs back to Train screen navController.previousBackStackEntry ?.savedStateHandle ?.set("selected_image_uris", uris) - navController.popBackStack() } ) } - /** DUMMY SCREENS FOR OTHER DRAWER ITEMS **/ - //composable(AppRoutes.TOUR) { DummyScreen("Tour (stub)") } - composable(AppRoutes.MODELS) { DummyScreen("Models (stub)") } - composable(AppRoutes.INVENTORY) { DummyScreen("Inventory (stub)") } - //composable(AppRoutes.TRAIN) { DummyScreen("Train (stub)") } - composable(AppRoutes.TAGS) { DummyScreen("Tags (stub)") } - composable(AppRoutes.UPLOAD) { DummyScreen("Upload (stub)") } - composable(AppRoutes.SETTINGS) { DummyScreen("Settings (stub)") } + /** + * MODELS SCREEN + * AI model management (placeholder) + */ + composable(AppRoutes.MODELS) { + DummyScreen( + title = "AI Models", + subtitle = "Manage face recognition models" + ) + } + + // ========================================== + // ORGANIZATION + // ========================================== + + /** + * TAGS SCREEN + * Manage photo tags (placeholder) + */ + composable(AppRoutes.TAGS) { + DummyScreen( + title = "Tags", + subtitle = "Organize your photos with tags" + ) + } + + /** + * UPLOAD SCREEN + * Import new photos (placeholder) + */ + composable(AppRoutes.UPLOAD) { + DummyScreen( + title = "Upload", + subtitle = "Add photos to your library" + ) + } + + // ========================================== + // SETTINGS + // ========================================== + + /** + * SETTINGS SCREEN + * App preferences (placeholder) + */ + composable(AppRoutes.SETTINGS) { + DummyScreen( + title = "Settings", + subtitle = "App preferences and configuration" + ) + } } -} +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt index 9e5644c..fc0c3b6 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/AppDrawerContent.kt @@ -1,85 +1,243 @@ package com.placeholder.sherpai2.ui.presentation +import androidx.compose.foundation.background import androidx.compose.foundation.layout.* +import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material3.* -import androidx.compose.material3.DividerDefaults import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Brush +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import androidx.compose.material.icons.Icons import androidx.compose.material.icons.automirrored.filled.Label import androidx.compose.material.icons.automirrored.filled.List import androidx.compose.material.icons.filled.* -import androidx.compose.material3.HorizontalDivider import com.placeholder.sherpai2.ui.navigation.AppRoutes +/** + * Beautiful app drawer with sections, gradient header, and polish + */ @OptIn(ExperimentalMaterial3Api::class) @Composable fun AppDrawerContent( currentRoute: String?, onDestinationClicked: (String) -> Unit ) { - // Drawer sheet with fixed width - ModalDrawerSheet(modifier = Modifier.width(280.dp)) { + ModalDrawerSheet( + modifier = Modifier.width(300.dp), + drawerContainerColor = MaterialTheme.colorScheme.surface + ) { + Column(modifier = Modifier.fillMaxSize()) { - // Header / Logo - Text( - "SherpAI Control Panel", - style = MaterialTheme.typography.headlineSmall, - modifier = Modifier.padding(16.dp) - ) + // ===== BEAUTIFUL GRADIENT HEADER ===== + Box( + modifier = Modifier + .fillMaxWidth() + .background( + Brush.verticalGradient( + colors = listOf( + MaterialTheme.colorScheme.primaryContainer, + MaterialTheme.colorScheme.surface + ) + ) + ) + .padding(24.dp) + ) { + Column( + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + // App icon/logo area + Surface( + modifier = Modifier.size(56.dp), + shape = RoundedCornerShape(16.dp), + color = MaterialTheme.colorScheme.primary, + shadowElevation = 4.dp + ) { + Box(contentAlignment = Alignment.Center) { + Icon( + Icons.Default.Face, + contentDescription = null, + modifier = Modifier.size(32.dp), + tint = MaterialTheme.colorScheme.onPrimary + ) + } + } - HorizontalDivider( - Modifier.fillMaxWidth(), - thickness = DividerDefaults.Thickness, - color = DividerDefaults.color - ) + Text( + "SherpAI", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onSurface + ) - // Main drawer items - val mainItems = listOf( - Triple(AppRoutes.SEARCH, "Search", Icons.Default.Search), - Triple(AppRoutes.TOUR, "Tour", Icons.Default.Place), - Triple(AppRoutes.MODELS, "Models", Icons.Default.ModelTraining), - Triple(AppRoutes.INVENTORY, "Inventory", Icons.AutoMirrored.Filled.List), - Triple(AppRoutes.TRAIN, "Train", Icons.Default.Train), - Triple(AppRoutes.TAGS, "Tags", Icons.AutoMirrored.Filled.Label) - ) - - Column(modifier = Modifier.padding(vertical = 8.dp)) { - mainItems.forEach { (route, label, icon) -> - NavigationDrawerItem( - label = { Text(label) }, - icon = { Icon(icon, contentDescription = label) }, - selected = route == currentRoute, - onClick = { onDestinationClicked(route) }, - modifier = Modifier.padding(NavigationDrawerItemDefaults.ItemPadding) - ) + Text( + "Face Recognition System", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } } - } - Divider( - Modifier - .fillMaxWidth() - .padding(vertical = 8.dp), - thickness = DividerDefaults.Thickness - ) + Spacer(modifier = Modifier.height(8.dp)) - // Utility items - val utilityItems = listOf( - Triple(AppRoutes.UPLOAD, "Upload", Icons.Default.UploadFile), - Triple(AppRoutes.SETTINGS, "Settings", Icons.Default.Settings) - ) + // ===== NAVIGATION SECTIONS ===== + Column( + modifier = Modifier + .fillMaxWidth() + .weight(1f) + .padding(horizontal = 12.dp), + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { - Column(modifier = Modifier.padding(vertical = 8.dp)) { - utilityItems.forEach { (route, label, icon) -> - NavigationDrawerItem( - label = { Text(label) }, - icon = { Icon(icon, contentDescription = label) }, - selected = route == currentRoute, - onClick = { onDestinationClicked(route) }, - modifier = Modifier.padding(NavigationDrawerItemDefaults.ItemPadding) + // Photos Section + DrawerSection(title = "Photos") + + val photoItems = listOf( + DrawerItem(AppRoutes.SEARCH, "Search", Icons.Default.Search, "Find photos by tag or person"), + DrawerItem(AppRoutes.TOUR, "Tour", Icons.Default.Place, "Browse by location & time") ) + + photoItems.forEach { item -> + DrawerNavigationItem( + item = item, + selected = item.route == currentRoute, + onClick = { onDestinationClicked(item.route) } + ) + } + + Spacer(modifier = Modifier.height(8.dp)) + + // Face Recognition Section + DrawerSection(title = "Face Recognition") + + val faceItems = listOf( + DrawerItem(AppRoutes.INVENTORY, "People", Icons.Default.Face, "Trained face models"), + DrawerItem(AppRoutes.TRAIN, "Train", Icons.Default.ModelTraining, "Train new person"), + DrawerItem(AppRoutes.MODELS, "Models", Icons.Default.SmartToy, "AI model management") + ) + + faceItems.forEach { item -> + DrawerNavigationItem( + item = item, + selected = item.route == currentRoute, + onClick = { onDestinationClicked(item.route) } + ) + } + + Spacer(modifier = Modifier.height(8.dp)) + + // Organization Section + DrawerSection(title = "Organization") + + val orgItems = listOf( + DrawerItem(AppRoutes.TAGS, "Tags", Icons.AutoMirrored.Filled.Label, "Manage photo tags"), + DrawerItem(AppRoutes.UPLOAD, "Upload", Icons.Default.UploadFile, "Add new photos") + ) + + orgItems.forEach { item -> + DrawerNavigationItem( + item = item, + selected = item.route == currentRoute, + onClick = { onDestinationClicked(item.route) } + ) + } + + Spacer(modifier = Modifier.weight(1f)) + + // Settings at bottom + HorizontalDivider( + modifier = Modifier.padding(vertical = 8.dp), + color = MaterialTheme.colorScheme.outlineVariant + ) + + DrawerNavigationItem( + item = DrawerItem( + AppRoutes.SETTINGS, + "Settings", + Icons.Default.Settings, + "App preferences" + ), + selected = AppRoutes.SETTINGS == currentRoute, + onClick = { onDestinationClicked(AppRoutes.SETTINGS) } + ) + + Spacer(modifier = Modifier.height(8.dp)) } } } } + +/** + * Section header in drawer + */ +@Composable +private fun DrawerSection(title: String) { + Text( + text = title, + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary, + modifier = Modifier.padding(horizontal = 16.dp, vertical = 8.dp) + ) +} + +/** + * Individual navigation item with icon, label, and subtitle + */ +@Composable +private fun DrawerNavigationItem( + item: DrawerItem, + selected: Boolean, + onClick: () -> Unit +) { + NavigationDrawerItem( + label = { + Column(verticalArrangement = Arrangement.spacedBy(2.dp)) { + Text( + text = item.label, + style = MaterialTheme.typography.bodyLarge, + fontWeight = if (selected) FontWeight.SemiBold else FontWeight.Normal + ) + item.subtitle?.let { + Text( + text = it, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + }, + icon = { + Icon( + item.icon, + contentDescription = item.label, + modifier = Modifier.size(24.dp) + ) + }, + selected = selected, + onClick = onClick, + modifier = Modifier + .padding(NavigationDrawerItemDefaults.ItemPadding) + .clip(RoundedCornerShape(12.dp)), + colors = NavigationDrawerItemDefaults.colors( + selectedContainerColor = MaterialTheme.colorScheme.primaryContainer, + selectedIconColor = MaterialTheme.colorScheme.primary, + selectedTextColor = MaterialTheme.colorScheme.onPrimaryContainer, + unselectedContainerColor = Color.Transparent + ) + ) +} + +/** + * Data class for drawer items + */ +private data class DrawerItem( + val route: String, + val label: String, + val icon: androidx.compose.ui.graphics.vector.ImageVector, + val subtitle: String? = null +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchScreen.kt index 35d3956..cd1bc5c 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchScreen.kt @@ -1,25 +1,34 @@ package com.placeholder.sherpai2.ui.search +import androidx.compose.foundation.background import androidx.compose.foundation.layout.* import androidx.compose.foundation.lazy.grid.* +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* import androidx.compose.material3.* import androidx.compose.runtime.* +import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Brush +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp import androidx.lifecycle.compose.collectAsStateWithLifecycle import com.placeholder.sherpai2.ui.search.components.ImageGridItem -import com.placeholder.sherpai2.ui.search.SearchViewModel /** - * SearchScreen + * Beautiful SearchScreen with face tag display * - * Purpose: - * - Validate tag-based queries - * - Preview matching images - * - * This is NOT final UX. - * It is a diagnostic surface. + * Polish improvements: + * - Gradient header + * - Better stats card + * - Smooth animations + * - Enhanced visual hierarchy */ +@OptIn(ExperimentalMaterial3Api::class) @Composable fun SearchScreen( modifier: Modifier = Modifier, @@ -29,42 +38,368 @@ fun SearchScreen( var query by remember { mutableStateOf("") } - /** - * Reactive result set. - * Updates whenever: - * - query changes - * - database changes - */ val images by searchViewModel .searchImagesByTag(query) .collectAsStateWithLifecycle(initialValue = emptyList()) - Column( - modifier = modifier - .fillMaxSize() - .padding(12.dp) - ) { + Scaffold( + topBar = { + // Gradient header + Box( + modifier = Modifier + .fillMaxWidth() + .background( + Brush.verticalGradient( + colors = listOf( + MaterialTheme.colorScheme.primaryContainer, + MaterialTheme.colorScheme.surface + ) + ) + ) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp) + ) { + // Title + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + Surface( + shape = RoundedCornerShape(12.dp), + color = MaterialTheme.colorScheme.primary, + shadowElevation = 2.dp, + modifier = Modifier.size(48.dp) + ) { + Box(contentAlignment = Alignment.Center) { + Icon( + Icons.Default.Search, + contentDescription = null, + tint = MaterialTheme.colorScheme.onPrimary, + modifier = Modifier.size(28.dp) + ) + } + } - OutlinedTextField( - value = query, - onValueChange = { query = it }, - label = { Text("Search by tag") }, - modifier = Modifier.fillMaxWidth(), - singleLine = true - ) + Column { + Text( + text = "Search Photos", + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold + ) + Text( + text = "Find by tag or person", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } - Spacer(modifier = Modifier.height(12.dp)) + Spacer(modifier = Modifier.height(16.dp)) - LazyVerticalGrid( - columns = GridCells.Adaptive(120.dp), - contentPadding = PaddingValues(4.dp), - verticalArrangement = Arrangement.spacedBy(4.dp), - horizontalArrangement = Arrangement.spacedBy(4.dp), - modifier = Modifier.fillMaxSize() + // Search bar + OutlinedTextField( + value = query, + onValueChange = { query = it }, + label = { Text("Search by tag") }, + leadingIcon = { + Icon(Icons.Default.Search, contentDescription = null) + }, + trailingIcon = { + if (query.isNotEmpty()) { + IconButton(onClick = { query = "" }) { + Icon(Icons.Default.Clear, contentDescription = "Clear") + } + } + }, + modifier = Modifier.fillMaxWidth(), + singleLine = true, + shape = RoundedCornerShape(16.dp), + colors = OutlinedTextFieldDefaults.colors( + focusedContainerColor = MaterialTheme.colorScheme.surface, + unfocusedContainerColor = MaterialTheme.colorScheme.surface + ) + ) + } + } + } + ) { paddingValues -> + Column( + modifier = modifier + .fillMaxSize() + .padding(paddingValues) ) { - items(images) { imageWithEverything -> - ImageGridItem(image = imageWithEverything.image) + // Stats bar + if (images.isNotEmpty()) { + StatsBar(images = images) + } + + // Results grid + if (images.isEmpty() && query.isBlank()) { + EmptySearchState() + } else if (images.isEmpty() && query.isNotBlank()) { + NoResultsState(query = query) + } else { + LazyVerticalGrid( + columns = GridCells.Adaptive(120.dp), + contentPadding = PaddingValues(12.dp), + verticalArrangement = Arrangement.spacedBy(12.dp), + horizontalArrangement = Arrangement.spacedBy(12.dp), + modifier = Modifier.fillMaxSize() + ) { + items( + items = images, + key = { it.image.imageId } + ) { imageWithFaceTags -> + ImageWithFaceTagsCard( + imageWithFaceTags = imageWithFaceTags, + onImageClick = onImageClick + ) + } + } } } } } + +/** + * Pretty stats bar showing results summary + */ +@Composable +private fun StatsBar(images: List) { + val totalFaces = images.sumOf { it.faceTags.size } + val uniquePersons = images.flatMap { it.persons }.distinctBy { it.id }.size + + Surface( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + color = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.5f), + shape = RoundedCornerShape(16.dp), + shadowElevation = 2.dp + ) { + Row( + modifier = Modifier.padding(16.dp), + horizontalArrangement = Arrangement.SpaceEvenly, + verticalAlignment = Alignment.CenterVertically + ) { + StatBadge( + icon = Icons.Default.Photo, + label = "Images", + value = images.size.toString() + ) + + VerticalDivider( + modifier = Modifier.height(40.dp), + color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) + ) + + StatBadge( + icon = Icons.Default.Face, + label = "Faces", + value = totalFaces.toString() + ) + + if (uniquePersons > 0) { + VerticalDivider( + modifier = Modifier.height(40.dp), + color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) + ) + + StatBadge( + icon = Icons.Default.People, + label = "People", + value = uniquePersons.toString() + ) + } + } + } +} + +@Composable +private fun StatBadge( + icon: androidx.compose.ui.graphics.vector.ImageVector, + label: String, + value: String +) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + Icon( + icon, + contentDescription = null, + modifier = Modifier.size(24.dp), + tint = MaterialTheme.colorScheme.primary + ) + Text( + text = value, + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + Text( + text = label, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } +} + +/** + * Empty state when no search query + */ +@Composable +private fun EmptySearchState() { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp), + modifier = Modifier.padding(32.dp) + ) { + Icon( + Icons.Default.Search, + contentDescription = null, + modifier = Modifier.size(80.dp), + tint = MaterialTheme.colorScheme.primary.copy(alpha = 0.3f) + ) + Text( + text = "Search your photos", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + Text( + text = "Enter a tag to find photos", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } +} + +/** + * No results state + */ +@Composable +private fun NoResultsState(query: String) { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp), + modifier = Modifier.padding(32.dp) + ) { + Icon( + Icons.Default.SearchOff, + contentDescription = null, + modifier = Modifier.size(80.dp), + tint = MaterialTheme.colorScheme.error.copy(alpha = 0.5f) + ) + Text( + text = "No results", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + Text( + text = "No photos found for \"$query\"", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } +} + +/** + * Beautiful card showing image with face tags + */ +@Composable +private fun ImageWithFaceTagsCard( + imageWithFaceTags: ImageWithFaceTags, + onImageClick: (String) -> Unit +) { + Card( + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(16.dp), + elevation = CardDefaults.cardElevation(defaultElevation = 4.dp) + ) { + Column( + modifier = Modifier.fillMaxWidth() + ) { + // Image + ImageGridItem( + image = imageWithFaceTags.image, + onClick = { onImageClick(imageWithFaceTags.image.imageId) } + ) + + // Face tags + if (imageWithFaceTags.persons.isNotEmpty()) { + Surface( + color = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f), + modifier = Modifier.fillMaxWidth() + ) { + Column( + modifier = Modifier.padding(8.dp), + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + imageWithFaceTags.persons.take(3).forEachIndexed { index, person -> + Row( + horizontalArrangement = Arrangement.spacedBy(6.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.Face, + contentDescription = null, + modifier = Modifier.size(16.dp), + tint = MaterialTheme.colorScheme.primary + ) + Text( + text = person.name, + style = MaterialTheme.typography.bodySmall, + fontWeight = FontWeight.Medium, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + modifier = Modifier.weight(1f) + ) + + if (index < imageWithFaceTags.faceTags.size) { + val confidence = (imageWithFaceTags.faceTags[index].confidence * 100).toInt() + Surface( + shape = RoundedCornerShape(8.dp), + color = if (confidence >= 80) { + MaterialTheme.colorScheme.primary.copy(alpha = 0.2f) + } else { + MaterialTheme.colorScheme.tertiary.copy(alpha = 0.2f) + } + ) { + Text( + text = "$confidence%", + style = MaterialTheme.typography.labelSmall, + modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp), + fontWeight = FontWeight.Bold + ) + } + } + } + } + + if (imageWithFaceTags.persons.size > 3) { + Text( + text = "+${imageWithFaceTags.persons.size - 3} more", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.primary, + fontWeight = FontWeight.Medium + ) + } + } + } + } + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchViewModel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchViewModel.kt index 8e7aa8e..42845fe 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchViewModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/search/SearchViewModel.kt @@ -1,24 +1,70 @@ package com.placeholder.sherpai2.ui.search import androidx.lifecycle.ViewModel +import com.placeholder.sherpai2.data.local.entity.ImageEntity +import com.placeholder.sherpai2.data.local.entity.PersonEntity +import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity +import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository import com.placeholder.sherpai2.domain.repository.ImageRepository import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map import javax.inject.Inject /** * SearchViewModel * - * Stateless except for query-driven flows. + * CLEAN IMPLEMENTATION: + * - Properly handles Flow types + * - Fetches face tags for each image + * - Returns combined data structure */ @HiltViewModel class SearchViewModel @Inject constructor( - private val imageRepository: ImageRepository + private val imageRepository: ImageRepository, + private val faceRecognitionRepository: FaceRecognitionRepository ) : ViewModel() { - fun searchImagesByTag(tag: String) = - if (tag.isBlank()) { + /** + * Search images by tag with face recognition data. + * + * RETURNS: Flow> + * Each image includes its detected faces and person names + */ + fun searchImagesByTag(tag: String): Flow> { + val imagesFlow = if (tag.isBlank()) { imageRepository.getAllImages() } else { imageRepository.findImagesByTag(tag) } + + // Transform Flow to include face recognition data + return imagesFlow.map { imagesList -> + imagesList.map { imageWithEverything -> + // Get face tags with person info for this image + val tagsWithPersons = faceRecognitionRepository.getFaceTagsWithPersons( + imageWithEverything.image.imageId + ) + + ImageWithFaceTags( + image = imageWithEverything.image, + faceTags = tagsWithPersons.map { it.first }, + persons = tagsWithPersons.map { it.second } + ) + } + } + } } + +/** + * Data class containing image with face recognition data + * + * @property image The image entity + * @property faceTags Face tags detected in this image + * @property persons Person entities (parallel to faceTags) + */ +data class ImageWithFaceTags( + val image: ImageEntity, + val faceTags: List, + val persons: List +) \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt index d30ee99..f1ca09b 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt @@ -8,12 +8,13 @@ import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.Image import androidx.compose.foundation.background import androidx.compose.foundation.border -import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.* import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.itemsIndexed import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.* import androidx.compose.material3.* @@ -25,6 +26,8 @@ import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.text.input.KeyboardCapitalization import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel @@ -38,6 +41,28 @@ fun ScanResultsScreen( trainViewModel: TrainViewModel = hiltViewModel() ) { var showFacePickerDialog by remember { mutableStateOf(null) } + var showNameInputDialog by remember { mutableStateOf(false) } + + // Observe training state + val trainingState by trainViewModel.trainingState.collectAsState() + + // Handle training state changes + LaunchedEffect(trainingState) { + when (trainingState) { + is TrainingState.Success -> { + // Training completed successfully + val success = trainingState as TrainingState.Success + // You can show a success message or navigate away + // For now, we'll just reset and finish + trainViewModel.resetTrainingState() + onFinish() + } + is TrainingState.Error -> { + // Error will be shown in dialog, no action needed here + } + else -> { /* Idle or Processing */ } + } + } Scaffold( topBar = { @@ -69,7 +94,10 @@ fun ScanResultsScreen( is ScanningState.Success -> { ImprovedResultsView( result = state.sanityCheckResult, - onContinue = onFinish, + onContinue = { + // Show name input dialog instead of immediately finishing + showNameInputDialog = true + }, onRetry = onFinish, onReplaceImage = { oldUri, newUri -> trainViewModel.replaceImage(oldUri, newUri) @@ -87,6 +115,11 @@ fun ScanResultsScreen( ) } } + + // Show training overlay if processing + if (trainingState is TrainingState.Processing) { + TrainingOverlay(trainingState = trainingState as TrainingState.Processing) + } } } @@ -101,6 +134,185 @@ fun ScanResultsScreen( } ) } + + // Name Input Dialog + if (showNameInputDialog) { + NameInputDialog( + onDismiss = { showNameInputDialog = false }, + onConfirm = { name -> + showNameInputDialog = false + trainViewModel.createFaceModel(name) + }, + trainingState = trainingState + ) + } +} + +/** + * Dialog for entering person's name before training + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +private fun NameInputDialog( + onDismiss: () -> Unit, + onConfirm: (String) -> Unit, + trainingState: TrainingState +) { + var personName by remember { mutableStateOf("") } + val isError = trainingState is TrainingState.Error + + AlertDialog( + onDismissRequest = { + if (trainingState !is TrainingState.Processing) { + onDismiss() + } + }, + title = { + Text( + text = if (isError) "Training Error" else "Who is this?", + style = MaterialTheme.typography.headlineSmall + ) + }, + text = { + Column( + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + if (isError) { + // Show error message + val error = trainingState as TrainingState.Error + Surface( + color = MaterialTheme.colorScheme.errorContainer, + shape = RoundedCornerShape(8.dp) + ) { + Row( + modifier = Modifier.padding(12.dp), + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.Warning, + contentDescription = null, + tint = MaterialTheme.colorScheme.error + ) + Text( + text = error.message, + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onErrorContainer + ) + } + } + } else { + Text( + text = "Enter the name of the person in these training images. This will help you find their photos later.", + style = MaterialTheme.typography.bodyMedium + ) + } + + OutlinedTextField( + value = personName, + onValueChange = { personName = it }, + label = { Text("Person's Name") }, + placeholder = { Text("e.g., John Doe") }, + singleLine = true, + enabled = trainingState !is TrainingState.Processing, + keyboardOptions = KeyboardOptions( + capitalization = KeyboardCapitalization.Words, + imeAction = ImeAction.Done + ), + keyboardActions = KeyboardActions( + onDone = { + if (personName.isNotBlank()) { + onConfirm(personName.trim()) + } + } + ), + modifier = Modifier.fillMaxWidth() + ) + } + }, + confirmButton = { + Button( + onClick = { onConfirm(personName.trim()) }, + enabled = personName.isNotBlank() && trainingState !is TrainingState.Processing + ) { + if (trainingState is TrainingState.Processing) { + CircularProgressIndicator( + modifier = Modifier.size(16.dp), + strokeWidth = 2.dp, + color = MaterialTheme.colorScheme.onPrimary + ) + Spacer(modifier = Modifier.width(8.dp)) + } + Text(if (isError) "Try Again" else "Start Training") + } + }, + dismissButton = { + if (trainingState !is TrainingState.Processing) { + TextButton(onClick = onDismiss) { + Text("Cancel") + } + } + } + ) +} + +/** + * Overlay shown during training process + */ +@Composable +private fun TrainingOverlay(trainingState: TrainingState.Processing) { + Box( + modifier = Modifier + .fillMaxSize() + .background(Color.Black.copy(alpha = 0.7f)), + contentAlignment = Alignment.Center + ) { + Card( + modifier = Modifier + .padding(32.dp) + .fillMaxWidth(0.9f), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surface + ) + ) { + Column( + modifier = Modifier.padding(24.dp), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + CircularProgressIndicator( + modifier = Modifier.size(64.dp), + strokeWidth = 6.dp + ) + + Text( + text = "Creating Face Model", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + + Text( + text = trainingState.stage, + style = MaterialTheme.typography.bodyMedium, + textAlign = TextAlign.Center, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + if (trainingState.total > 0) { + LinearProgressIndicator( + progress = { (trainingState.progress.toFloat() / trainingState.total.toFloat()).coerceIn(0f, 1f) }, + modifier = Modifier.fillMaxWidth() + ) + + Text( + text = "${trainingState.progress} / ${trainingState.total}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + } } @Composable @@ -579,7 +791,7 @@ private fun ValidationIssuesCard(errors: List when (error) { diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt index f987c8c..e890485 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt @@ -5,6 +5,8 @@ import android.graphics.Bitmap import android.net.Uri import androidx.lifecycle.AndroidViewModel import androidx.lifecycle.viewModelScope +import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository +import com.placeholder.sherpai2.ml.FaceNetModel import dagger.hilt.android.lifecycle.HiltViewModel import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow @@ -12,6 +14,9 @@ import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.launch import javax.inject.Inject +/** + * State for image scanning and validation + */ sealed class ScanningState { object Idle : ScanningState() data class Processing(val progress: Int, val total: Int) : ScanningState() @@ -21,17 +26,44 @@ sealed class ScanningState { data class Error(val message: String) : ScanningState() } +/** + * State for face model training/creation + */ +sealed class TrainingState { + object Idle : TrainingState() + data class Processing(val stage: String, val progress: Int, val total: Int) : TrainingState() + data class Success(val personName: String, val personId: String) : TrainingState() + data class Error(val message: String) : TrainingState() +} + +/** + * ViewModel for training face recognition models + * + * WORKFLOW: + * 1. User selects 10+ images → scanAndTagFaces() + * 2. Images validated → Success state with validImagesWithFaces + * 3. User can replace images or pick faces from group photos + * 4. When ready → createFaceModel(personName) + * 5. Creates PersonEntity + FaceModelEntity in database + */ @HiltViewModel class TrainViewModel @Inject constructor( - application: Application + application: Application, + private val faceRecognitionRepository: FaceRecognitionRepository, + private val faceNetModel: FaceNetModel ) : AndroidViewModel(application) { private val sanityChecker = TrainingSanityChecker(application) private val faceDetectionHelper = FaceDetectionHelper(application) + // Scanning/validation state private val _uiState = MutableStateFlow(ScanningState.Idle) val uiState: StateFlow = _uiState.asStateFlow() + // Training/model creation state + private val _trainingState = MutableStateFlow(TrainingState.Idle) + val trainingState: StateFlow = _trainingState.asStateFlow() + // Keep track of current images for replacements private var currentImageUris: List = emptyList() @@ -43,8 +75,101 @@ class TrainViewModel @Inject constructor( val croppedFaceBitmap: Bitmap ) + // ====================== + // FACE MODEL CREATION + // ====================== + /** - * Scan and validate images for training + * Create face model from validated training images. + * + * COMPLETE PROCESS: + * 1. Verify we have 10+ validated images + * 2. Call repository to create PersonEntity + FaceModelEntity + * 3. Repository handles: embedding generation, averaging, database save + * + * Call this when user clicks "Continue to Training" after validation passes. + * + * @param personName Name for the new person + * + * EXAMPLE USAGE IN UI: + * if (result.isValid) { + * showNameDialog { name -> + * trainViewModel.createFaceModel(name) + * } + * } + */ + fun createFaceModel(personName: String) { + val currentState = _uiState.value + if (currentState !is ScanningState.Success) { + _trainingState.value = TrainingState.Error("No validated images available") + return + } + + val validImages = currentState.sanityCheckResult.validImagesWithFaces + if (validImages.size < 10) { + _trainingState.value = TrainingState.Error("Need at least 10 valid images, have ${validImages.size}") + return + } + + viewModelScope.launch { + try { + _trainingState.value = TrainingState.Processing( + stage = "Creating person and training model", + progress = 0, + total = validImages.size + ) + + // Repository handles everything: + // - Creates PersonEntity in 'persons' table + // - Generates embeddings from face bitmaps + // - Averages embeddings + // - Creates FaceModelEntity linked to PersonEntity + val personId = faceRecognitionRepository.createPersonWithFaceModel( + personName = personName, + validImages = validImages, + onProgress = { current, total -> + _trainingState.value = TrainingState.Processing( + stage = "Processing image $current/$total", + progress = current, + total = total + ) + } + ) + + _trainingState.value = TrainingState.Success( + personName = personName, + personId = personId + ) + + } catch (e: Exception) { + _trainingState.value = TrainingState.Error( + e.message ?: "Failed to create face model" + ) + } + } + } + + /** + * Reset training state back to idle. + * Call this after handling success/error. + */ + fun resetTrainingState() { + _trainingState.value = TrainingState.Idle + } + + // ====================== + // IMAGE VALIDATION + // ====================== + + /** + * Scan and validate images for training. + * + * PROCESS: + * 1. Face detection on all images + * 2. Duplicate checking + * 3. Validation against requirements (10+ images, one face per image) + * + * @param imageUris List of image URIs selected by user */ fun scanAndTagFaces(imageUris: List) { currentImageUris = imageUris @@ -53,7 +178,10 @@ class TrainViewModel @Inject constructor( } /** - * Replace a single image and re-scan + * Replace a single image and re-scan all images. + * + * @param oldUri Image to replace + * @param newUri New image */ fun replaceImage(oldUri: Uri, newUri: Uri) { viewModelScope.launch { @@ -74,7 +202,11 @@ class TrainViewModel @Inject constructor( } /** - * User manually selected a face from a multi-face image + * User manually selected a face from a multi-face image. + * + * @param imageUri Image with multiple faces + * @param faceIndex Which face the user selected (0-based) + * @param croppedFaceBitmap Cropped face bitmap */ fun selectFaceFromImage(imageUri: Uri, faceIndex: Int, croppedFaceBitmap: Bitmap) { manualFaceSelections[imageUri] = ManualFaceSelection(faceIndex, croppedFaceBitmap) @@ -88,7 +220,7 @@ class TrainViewModel @Inject constructor( } /** - * Perform the actual scanning + * Perform the actual scanning. */ private fun performScan(imageUris: List) { viewModelScope.launch { @@ -117,7 +249,7 @@ class TrainViewModel @Inject constructor( } /** - * Apply manual face selections to the results + * Apply manual face selections to the results. */ private fun applyManualSelections( result: TrainingSanityChecker.SanityCheckResult @@ -192,17 +324,18 @@ class TrainViewModel @Inject constructor( } /** - * Get formatted error messages + * Get formatted error messages. */ fun getFormattedErrors(result: TrainingSanityChecker.SanityCheckResult): List { return sanityChecker.formatValidationErrors(result.validationErrors) } /** - * Reset to idle state + * Reset to idle state. */ fun reset() { _uiState.value = ScanningState.Idle + _trainingState.value = TrainingState.Idle currentImageUris = emptyList() manualFaceSelections.clear() } @@ -211,10 +344,17 @@ class TrainViewModel @Inject constructor( super.onCleared() sanityChecker.cleanup() faceDetectionHelper.cleanup() + faceNetModel.close() } } -// Extension function to copy FaceDetectionResult with modifications +// ====================== +// EXTENSION FUNCTIONS +// ====================== + +/** + * Extension to copy FaceDetectionResult with modifications. + */ private fun FaceDetectionHelper.FaceDetectionResult.copy( uri: Uri = this.uri, hasFace: Boolean = this.hasFace, @@ -233,7 +373,9 @@ private fun FaceDetectionHelper.FaceDetectionResult.copy( ) } -// Extension function to copy SanityCheckResult with modifications +/** + * Extension to copy SanityCheckResult with modifications. + */ private fun TrainingSanityChecker.SanityCheckResult.copy( isValid: Boolean = this.isValid, faceDetectionResults: List = this.faceDetectionResults, diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 80bd766..9966887 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -19,10 +19,15 @@ room = "2.8.4" # Images coil = "2.7.0" -#Face Detect +# Face Detect mlkit-face-detection = "16.1.6" coroutines-play-services = "1.8.1" +# Models +tensorflow-lite = "2.14.0" +tensorflow-lite-support = "0.4.4" +gson = "2.10.1" + [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycle" } @@ -56,10 +61,17 @@ coil-compose = { group = "io.coil-kt", name = "coil-compose", version.ref = "coi mlkit-face-detection = { group = "com.google.mlkit", name = "face-detection", version.ref = "mlkit-face-detection"} kotlinx-coroutines-play-services = {group = "org.jetbrains.kotlinx",name = "kotlinx-coroutines-play-services",version.ref = "coroutines-play-services"} +# TensorFlow Lite for FaceNet +tensorflow-lite = { group = "org.tensorflow", name = "tensorflow-lite", version.ref = "tensorflow-lite" } +tensorflow-lite-support = { group = "org.tensorflow", name = "tensorflow-lite-support", version.ref = "tensorflow-lite-support" } +tensorflow-lite-gpu = { group = "org.tensorflow", name = "tensorflow-lite-gpu", version.ref = "tensorflow-lite" } + +gson = { group = "com.google.code.gson", name = "gson", version.ref = "gson" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } -hilt-android = { id = "com.google.dagger.hilt.android", version.ref = "hilt" } \ No newline at end of file +hilt-android = { id = "com.google.dagger.hilt.android", version.ref = "hilt" } +