This commit is contained in:
genki
2026-01-25 22:01:46 -05:00
parent 941337f671
commit 1ef8faad17
12 changed files with 174 additions and 98 deletions

View File

@@ -4,7 +4,7 @@
<selectionStates> <selectionStates>
<SelectionState runConfigName="app"> <SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-25T20:45:06.118763497Z"> <DropdownSelection timestamp="2026-01-26T02:23:12.309011764Z">
<Target type="DEFAULT_BOOT"> <Target type="DEFAULT_BOOT">
<handle> <handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" /> <DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />

View File

@@ -48,6 +48,9 @@ dependencies {
implementation(libs.androidx.lifecycle.viewmodel.compose) implementation(libs.androidx.lifecycle.viewmodel.compose)
implementation(libs.androidx.activity.compose) implementation(libs.androidx.activity.compose)
// DataStore Preferences
implementation("androidx.datastore:datastore-preferences:1.1.1")
// Compose // Compose
implementation(platform(libs.androidx.compose.bom)) implementation(platform(libs.androidx.compose.bom))
implementation(libs.androidx.compose.ui) implementation(libs.androidx.compose.ui)

View File

@@ -66,6 +66,9 @@ interface ImageDao {
@Query("SELECT * FROM images WHERE imageId = :imageId") @Query("SELECT * FROM images WHERE imageId = :imageId")
suspend fun getImageById(imageId: String): ImageEntity? suspend fun getImageById(imageId: String): ImageEntity?
@Query("SELECT * FROM images WHERE imageUri = :uri LIMIT 1")
suspend fun getImageByUri(uri: String): ImageEntity?
/** /**
* Stream images ordered by capture time (newest first). * Stream images ordered by capture time (newest first).
* *

View File

@@ -2,8 +2,10 @@ package com.placeholder.sherpai2.data.repository
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.util.Log
import com.placeholder.sherpai2.data.local.dao.FaceModelDao import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
import com.placeholder.sherpai2.data.local.dao.PersonDao import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.* import com.placeholder.sherpai2.data.local.entity.*
@@ -31,8 +33,12 @@ class FaceRecognitionRepository @Inject constructor(
private val personDao: PersonDao, private val personDao: PersonDao,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao, private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao private val photoFaceTagDao: PhotoFaceTagDao,
private val personAgeTagDao: PersonAgeTagDao
) { ) {
companion object {
private const val TAG = "FaceRecognitionRepo"
}
private val faceNetModel by lazy { FaceNetModel(context) } private val faceNetModel by lazy { FaceNetModel(context) }
@@ -181,12 +187,15 @@ class FaceRecognitionRepository @Inject constructor(
var highestSimilarity = threshold var highestSimilarity = threshold
for (faceModel in faceModels) { for (faceModel in faceModels) {
val modelEmbedding = faceModel.getEmbeddingArray() // Check ALL centroids for best match (critical for children with age centroids)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) val centroids = faceModel.getCentroids()
val bestCentroidSimilarity = centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid.getEmbeddingArray())
} ?: 0f
if (similarity > highestSimilarity) { if (bestCentroidSimilarity > highestSimilarity) {
highestSimilarity = similarity highestSimilarity = bestCentroidSimilarity
bestMatch = Pair(faceModel.id, similarity) bestMatch = Pair(faceModel.id, bestCentroidSimilarity)
} }
} }
@@ -374,9 +383,49 @@ class FaceRecognitionRepository @Inject constructor(
onProgress = onProgress onProgress = onProgress
) )
// Generate age tags for children
if (person.isChild && person.dateOfBirth != null) {
generateAgeTagsForTraining(person, validImages)
}
person.id person.id
} }
/**
* Generate age tags from training images for a child
*/
private suspend fun generateAgeTagsForTraining(
person: PersonEntity,
validImages: List<TrainingSanityChecker.ValidTrainingImage>
) {
try {
val dob = person.dateOfBirth ?: return
val tags = validImages.mapNotNull { img ->
val imageEntity = imageDao.getImageByUri(img.uri.toString()) ?: return@mapNotNull null
val ageMs = imageEntity.capturedAt - dob
val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt()
if (ageYears < 0 || ageYears > 25) return@mapNotNull null
PersonAgeTagEntity.create(
personId = person.id,
personName = person.name,
imageId = imageEntity.imageId,
ageAtCapture = ageYears,
confidence = 1.0f
)
}
if (tags.isNotEmpty()) {
personAgeTagDao.insertTags(tags)
Log.d(TAG, "Created ${tags.size} age tags for ${person.name}")
}
} catch (e: Exception) {
Log.e(TAG, "Failed to generate age tags", e)
}
}
/** /**
* Get face model by ID * Get face model by ID
*/ */

View File

@@ -61,14 +61,16 @@ abstract class RepositoryModule {
personDao: PersonDao, personDao: PersonDao,
imageDao: ImageDao, imageDao: ImageDao,
faceModelDao: FaceModelDao, faceModelDao: FaceModelDao,
photoFaceTagDao: PhotoFaceTagDao photoFaceTagDao: PhotoFaceTagDao,
personAgeTagDao: PersonAgeTagDao
): FaceRecognitionRepository { ): FaceRecognitionRepository {
return FaceRecognitionRepository( return FaceRecognitionRepository(
context = context, context = context,
personDao = personDao, personDao = personDao,
imageDao = imageDao, imageDao = imageDao,
faceModelDao = faceModelDao, faceModelDao = faceModelDao,
photoFaceTagDao = photoFaceTagDao photoFaceTagDao = photoFaceTagDao,
personAgeTagDao = personAgeTagDao
) )
} }

View File

@@ -15,6 +15,7 @@ import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.ui.discover.DiscoverySettings import com.placeholder.sherpai2.ui.discover.DiscoverySettings
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@@ -344,14 +345,9 @@ class FaceClusteringService @Inject constructor(
} }
try { try {
// Crop and generate embedding // Crop and normalize face
val faceBitmap = Bitmap.createBitmap( val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, mlFace)
bitmap, ?: return@forEach
mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1),
mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1),
mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left),
mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top)
)
val embedding = faceNetModel.generateEmbedding(faceBitmap) val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()
@@ -591,13 +587,8 @@ class FaceClusteringService @Inject constructor(
if (!qualityCheck.isValid) return@mapNotNull null if (!qualityCheck.isValid) return@mapNotNull null
try { try {
val faceBitmap = Bitmap.createBitmap( val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
bitmap, ?: return@mapNotNull null
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
val embedding = faceNetModel.generateEmbedding(faceBitmap) val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()

View File

@@ -339,10 +339,7 @@ fun AppNavHost(
* SETTINGS SCREEN * SETTINGS SCREEN
*/ */
composable(AppRoutes.SETTINGS) { composable(AppRoutes.SETTINGS) {
DummyScreen( com.placeholder.sherpai2.ui.settings.SettingsScreen()
title = "Settings",
subtitle = "App preferences and configuration"
)
} }
} }
} }

View File

@@ -78,6 +78,7 @@ fun MainScreen(
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW! AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
AppRoutes.INVENTORY -> "People" AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train Model" AppRoutes.TRAIN -> "Train Model"
AppRoutes.ScanResultsScreen -> "Train New Person"
AppRoutes.TAGS -> "Tags" AppRoutes.TAGS -> "Tags"
AppRoutes.UTILITIES -> "Utilities" AppRoutes.UTILITIES -> "Utilities"
AppRoutes.SETTINGS -> "Settings" AppRoutes.SETTINGS -> "Settings"

View File

@@ -6,8 +6,11 @@ import android.graphics.BitmapFactory
import android.graphics.Rect import android.graphics.Rect
import android.net.Uri import android.net.Uri
import com.google.mlkit.vision.common.InputImage import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.Face
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll import kotlinx.coroutines.awaitAll
@@ -64,21 +67,29 @@ class FaceDetectionHelper(private val context: Context) {
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await() val faces = detector.process(inputImage).await()
// Sort by face size (area) to get the largest face // Filter to quality faces only
val sortedFaces = faces.sortedByDescending { face -> val qualityFaces = faces.filter { face ->
FaceQualityFilter.validateForDiscovery(
face = face,
imageWidth = bitmap.width,
imageHeight = bitmap.height
).isValid
}
// Sort by face size (area) to get the largest quality face
val sortedFaces = qualityFaces.sortedByDescending { face ->
face.boundingBox.width() * face.boundingBox.height() face.boundingBox.width() * face.boundingBox.height()
} }
val croppedFace = if (sortedFaces.isNotEmpty()) { val croppedFace = if (sortedFaces.isNotEmpty()) {
// Crop the LARGEST detected face (most likely the subject) FaceNormalizer.cropAndNormalize(bitmap, sortedFaces[0])
cropFaceFromBitmap(bitmap, sortedFaces[0].boundingBox)
} else null } else null
FaceDetectionResult( FaceDetectionResult(
uri = uri, uri = uri,
hasFace = faces.isNotEmpty(), hasFace = qualityFaces.isNotEmpty(),
faceCount = faces.size, faceCount = qualityFaces.size,
faceBounds = faces.map { it.boundingBox }, faceBounds = qualityFaces.map { it.boundingBox },
croppedFaceBitmap = croppedFace croppedFaceBitmap = croppedFace
) )
} catch (e: Exception) { } catch (e: Exception) {

View File

@@ -51,21 +51,8 @@ fun ScanResultsScreen(
} }
} }
Scaffold( // No Scaffold - MainScreen provides TopAppBar
topBar = { Box(modifier = Modifier.fillMaxSize()) {
TopAppBar(
title = { Text("Train New Person") },
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
)
)
}
) { paddingValues ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(paddingValues)
) {
when (state) { when (state) {
is ScanningState.Idle -> {} is ScanningState.Idle -> {}
@@ -77,8 +64,6 @@ fun ScanResultsScreen(
ImprovedResultsView( ImprovedResultsView(
result = state.sanityCheckResult, result = state.sanityCheckResult,
onContinue = { onContinue = {
// PersonInfo already captured in TrainingScreen!
// Just start training with stored info
trainViewModel.createFaceModel( trainViewModel.createFaceModel(
trainViewModel.getPersonInfo()?.name ?: "Unknown" trainViewModel.getPersonInfo()?.name ?: "Unknown"
) )
@@ -103,7 +88,6 @@ fun ScanResultsScreen(
TrainingOverlay(trainingState = trainingState as TrainingState.Processing) TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
} }
} }
}
showFacePickerDialog?.let { result -> showFacePickerDialog?.let { result ->
FacePickerDialog( FacePickerDialog(

View File

@@ -5,11 +5,18 @@ import android.graphics.Bitmap
import android.net.Uri import android.net.Uri
import androidx.lifecycle.AndroidViewModel import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import androidx.work.WorkManager
import android.content.Context
import com.placeholder.sherpai2.data.local.entity.PersonEntity import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.workers.LibraryScanWorker
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@@ -48,15 +55,20 @@ data class PersonInfo(
/** /**
* FIXED TrainViewModel with proper exclude functionality and efficient replace * FIXED TrainViewModel with proper exclude functionality and efficient replace
*/ */
private val android.content.Context.dataStore by preferencesDataStore(name = "settings")
private val KEY_BACKGROUND_TAGGING = booleanPreferencesKey("background_recognition_tagging")
@HiltViewModel @HiltViewModel
class TrainViewModel @Inject constructor( class TrainViewModel @Inject constructor(
application: Application, application: Application,
private val faceRecognitionRepository: FaceRecognitionRepository, private val faceRecognitionRepository: FaceRecognitionRepository,
private val faceNetModel: FaceNetModel private val faceNetModel: FaceNetModel,
private val workManager: WorkManager
) : AndroidViewModel(application) { ) : AndroidViewModel(application) {
private val sanityChecker = TrainingSanityChecker(application) private val sanityChecker = TrainingSanityChecker(application)
private val faceDetectionHelper = FaceDetectionHelper(application) private val faceDetectionHelper = FaceDetectionHelper(application)
private val dataStore = application.dataStore
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle) private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow() val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
@@ -174,6 +186,20 @@ class TrainViewModel @Inject constructor(
relationship = person.relationship relationship = person.relationship
) )
// Trigger library scan if setting enabled
val backgroundTaggingEnabled = dataStore.data
.map { it[KEY_BACKGROUND_TAGGING] ?: true }
.first()
if (backgroundTaggingEnabled) {
val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId,
personName = personName,
threshold = 0.65f
)
workManager.enqueue(scanRequest)
}
} catch (e: Exception) { } catch (e: Exception) {
_trainingState.value = TrainingState.Error( _trainingState.value = TrainingState.Error(
e.message ?: "Failed to create face model" e.message ?: "Failed to create face model"

View File

@@ -9,6 +9,8 @@ import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
@@ -146,7 +148,12 @@ class LibraryScanWorker @AssistedInject constructor(
.build() .build()
) )
val modelEmbedding = faceModel.getEmbeddingArray() // Get ALL centroids for multi-centroid matching (critical for children)
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
if (modelCentroids.isEmpty()) {
return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
}
var matchesFound = 0 var matchesFound = 0
var photosScanned = 0 var photosScanned = 0
@@ -164,7 +171,7 @@ class LibraryScanWorker @AssistedInject constructor(
photo = photo, photo = photo,
personId = personId, personId = personId,
faceModelId = faceModel.id, faceModelId = faceModel.id,
modelEmbedding = modelEmbedding, modelCentroids = modelCentroids,
faceNetModel = faceNetModel, faceNetModel = faceNetModel,
detector = detector, detector = detector,
threshold = threshold threshold = threshold
@@ -228,7 +235,7 @@ class LibraryScanWorker @AssistedInject constructor(
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity, photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
personId: String, personId: String,
faceModelId: String, faceModelId: String,
modelEmbedding: FloatArray, modelCentroids: List<FloatArray>,
faceNetModel: FaceNetModel, faceNetModel: FaceNetModel,
detector: com.google.mlkit.vision.face.FaceDetector, detector: com.google.mlkit.vision.face.FaceDetector,
threshold: Float threshold: Float
@@ -243,24 +250,26 @@ class LibraryScanWorker @AssistedInject constructor(
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await() val faces = detector.process(inputImage).await()
// Check each face // Check each face (filter by quality first)
val tags = faces.mapNotNull { face -> val tags = faces.mapNotNull { face ->
// Quality check
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
return@mapNotNull null
}
try { try {
// Crop face // Crop and normalize face for best recognition
val faceBitmap = android.graphics.Bitmap.createBitmap( val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
bitmap, ?: return@mapNotNull null
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
// Generate embedding // Generate embedding
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()
// Calculate similarity // Match against ALL centroids, use best match (critical for children)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) val similarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
if (similarity >= threshold) { if (similarity >= threshold) {
PhotoFaceTagEntity.create( PhotoFaceTagEntity.create(