This commit is contained in:
genki
2026-01-25 22:01:46 -05:00
parent 941337f671
commit 1ef8faad17
12 changed files with 174 additions and 98 deletions

View File

@@ -4,7 +4,7 @@
<selectionStates>
<SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-25T20:45:06.118763497Z">
<DropdownSelection timestamp="2026-01-26T02:23:12.309011764Z">
<Target type="DEFAULT_BOOT">
<handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />

View File

@@ -48,6 +48,9 @@ dependencies {
implementation(libs.androidx.lifecycle.viewmodel.compose)
implementation(libs.androidx.activity.compose)
// DataStore Preferences
implementation("androidx.datastore:datastore-preferences:1.1.1")
// Compose
implementation(platform(libs.androidx.compose.bom))
implementation(libs.androidx.compose.ui)

View File

@@ -66,6 +66,9 @@ interface ImageDao {
@Query("SELECT * FROM images WHERE imageId = :imageId")
suspend fun getImageById(imageId: String): ImageEntity?
@Query("SELECT * FROM images WHERE imageUri = :uri LIMIT 1")
suspend fun getImageByUri(uri: String): ImageEntity?
/**
* Stream images ordered by capture time (newest first).
*

View File

@@ -2,8 +2,10 @@ package com.placeholder.sherpai2.data.repository
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.*
@@ -31,8 +33,12 @@ class FaceRecognitionRepository @Inject constructor(
private val personDao: PersonDao,
private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao
private val photoFaceTagDao: PhotoFaceTagDao,
private val personAgeTagDao: PersonAgeTagDao
) {
companion object {
private const val TAG = "FaceRecognitionRepo"
}
private val faceNetModel by lazy { FaceNetModel(context) }
@@ -181,12 +187,15 @@ class FaceRecognitionRepository @Inject constructor(
var highestSimilarity = threshold
for (faceModel in faceModels) {
val modelEmbedding = faceModel.getEmbeddingArray()
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
// Check ALL centroids for best match (critical for children with age centroids)
val centroids = faceModel.getCentroids()
val bestCentroidSimilarity = centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid.getEmbeddingArray())
} ?: 0f
if (similarity > highestSimilarity) {
highestSimilarity = similarity
bestMatch = Pair(faceModel.id, similarity)
if (bestCentroidSimilarity > highestSimilarity) {
highestSimilarity = bestCentroidSimilarity
bestMatch = Pair(faceModel.id, bestCentroidSimilarity)
}
}
@@ -374,9 +383,49 @@ class FaceRecognitionRepository @Inject constructor(
onProgress = onProgress
)
// Generate age tags for children
if (person.isChild && person.dateOfBirth != null) {
generateAgeTagsForTraining(person, validImages)
}
person.id
}
/**
 * Generates age tags for a child from the images that were used for training.
 *
 * Each training image is looked up in the image table by its URI. For every image found,
 * the age at capture (whole years, `capturedAt - dateOfBirth`) is computed and stored as a
 * [PersonAgeTagEntity] with confidence `1.0f`. Images that are not in the table, were
 * captured before the date of birth, or yield an age above 25 years are skipped.
 *
 * This is best-effort: failures are logged and not propagated to the caller, with the
 * exception of coroutine cancellation, which is rethrown.
 *
 * @param person the person to tag; nothing is done when `dateOfBirth` is null
 * @param validImages the training images whose capture time is used to derive ages
 */
private suspend fun generateAgeTagsForTraining(
    person: PersonEntity,
    validImages: List<TrainingSanityChecker.ValidTrainingImage>
) {
    val dob = person.dateOfBirth ?: return
    // Average year length used to convert the time difference into years.
    // NOTE(review): assumes capturedAt and dateOfBirth are both epoch milliseconds - confirm.
    val msPerYear = 365.25 * 24 * 60 * 60 * 1000

    try {
        val tags = validImages.mapNotNull { img ->
            val imageEntity = imageDao.getImageByUri(img.uri.toString()) ?: return@mapNotNull null
            val ageMs = imageEntity.capturedAt - dob

            // Reject on the raw difference: toInt() truncates toward zero, so a capture
            // less than one year BEFORE birth would otherwise be tagged as age 0.
            if (ageMs < 0) return@mapNotNull null

            val ageYears = (ageMs / msPerYear).toInt()
            if (ageYears > 25) return@mapNotNull null

            PersonAgeTagEntity.create(
                personId = person.id,
                personName = person.name,
                imageId = imageEntity.imageId,
                ageAtCapture = ageYears,
                confidence = 1.0f
            )
        }

        if (tags.isNotEmpty()) {
            personAgeTagDao.insertTags(tags)
            Log.d(TAG, "Created ${tags.size} age tags for ${person.name}")
        }
    } catch (e: kotlin.coroutines.cancellation.CancellationException) {
        // Never swallow cancellation in a suspend function.
        throw e
    } catch (e: Exception) {
        Log.e(TAG, "Failed to generate age tags", e)
    }
}
/**
* Get face model by ID
*/

View File

@@ -61,14 +61,16 @@ abstract class RepositoryModule {
personDao: PersonDao,
imageDao: ImageDao,
faceModelDao: FaceModelDao,
photoFaceTagDao: PhotoFaceTagDao
photoFaceTagDao: PhotoFaceTagDao,
personAgeTagDao: PersonAgeTagDao
): FaceRecognitionRepository {
return FaceRecognitionRepository(
context = context,
personDao = personDao,
imageDao = imageDao,
faceModelDao = faceModelDao,
photoFaceTagDao = photoFaceTagDao
photoFaceTagDao = photoFaceTagDao,
personAgeTagDao = personAgeTagDao
)
}

View File

@@ -15,6 +15,7 @@ import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.ui.discover.DiscoverySettings
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
@@ -344,14 +345,9 @@ class FaceClusteringService @Inject constructor(
}
try {
// Crop and generate embedding
val faceBitmap = Bitmap.createBitmap(
bitmap,
mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1),
mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1),
mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left),
mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top)
)
// Crop and normalize face
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, mlFace)
?: return@forEach
val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
@@ -591,13 +587,8 @@ class FaceClusteringService @Inject constructor(
if (!qualityCheck.isValid) return@mapNotNull null
try {
val faceBitmap = Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
?: return@mapNotNull null
val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()

View File

@@ -339,10 +339,7 @@ fun AppNavHost(
* SETTINGS SCREEN
*/
composable(AppRoutes.SETTINGS) {
DummyScreen(
title = "Settings",
subtitle = "App preferences and configuration"
)
com.placeholder.sherpai2.ui.settings.SettingsScreen()
}
}
}

View File

@@ -78,6 +78,7 @@ fun MainScreen(
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train Model"
AppRoutes.ScanResultsScreen -> "Train New Person"
AppRoutes.TAGS -> "Tags"
AppRoutes.UTILITIES -> "Utilities"
AppRoutes.SETTINGS -> "Settings"

View File

@@ -6,8 +6,11 @@ import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.Face
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
@@ -64,21 +67,29 @@ class FaceDetectionHelper(private val context: Context) {
val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await()
// Sort by face size (area) to get the largest face
val sortedFaces = faces.sortedByDescending { face ->
// Filter to quality faces only
val qualityFaces = faces.filter { face ->
FaceQualityFilter.validateForDiscovery(
face = face,
imageWidth = bitmap.width,
imageHeight = bitmap.height
).isValid
}
// Sort by face size (area) to get the largest quality face
val sortedFaces = qualityFaces.sortedByDescending { face ->
face.boundingBox.width() * face.boundingBox.height()
}
val croppedFace = if (sortedFaces.isNotEmpty()) {
// Crop the LARGEST detected face (most likely the subject)
cropFaceFromBitmap(bitmap, sortedFaces[0].boundingBox)
FaceNormalizer.cropAndNormalize(bitmap, sortedFaces[0])
} else null
FaceDetectionResult(
uri = uri,
hasFace = faces.isNotEmpty(),
faceCount = faces.size,
faceBounds = faces.map { it.boundingBox },
hasFace = qualityFaces.isNotEmpty(),
faceCount = qualityFaces.size,
faceBounds = qualityFaces.map { it.boundingBox },
croppedFaceBitmap = croppedFace
)
} catch (e: Exception) {

View File

@@ -51,57 +51,41 @@ fun ScanResultsScreen(
}
}
Scaffold(
topBar = {
TopAppBar(
title = { Text("Train New Person") },
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
// No Scaffold - MainScreen provides TopAppBar
Box(modifier = Modifier.fillMaxSize()) {
when (state) {
is ScanningState.Idle -> {}
is ScanningState.Processing -> {
ProcessingView(progress = state.progress, total = state.total)
}
is ScanningState.Success -> {
ImprovedResultsView(
result = state.sanityCheckResult,
onContinue = {
trainViewModel.createFaceModel(
trainViewModel.getPersonInfo()?.name ?: "Unknown"
)
},
onRetry = onFinish,
onReplaceImage = { oldUri, newUri ->
trainViewModel.replaceImage(oldUri, newUri)
},
onSelectFaceFromMultiple = { result ->
showFacePickerDialog = result
},
trainViewModel = trainViewModel
)
)
}
is ScanningState.Error -> {
ErrorView(message = state.message, onRetry = onFinish)
}
}
) { paddingValues ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(paddingValues)
) {
when (state) {
is ScanningState.Idle -> {}
is ScanningState.Processing -> {
ProcessingView(progress = state.progress, total = state.total)
}
is ScanningState.Success -> {
ImprovedResultsView(
result = state.sanityCheckResult,
onContinue = {
// PersonInfo already captured in TrainingScreen!
// Just start training with stored info
trainViewModel.createFaceModel(
trainViewModel.getPersonInfo()?.name ?: "Unknown"
)
},
onRetry = onFinish,
onReplaceImage = { oldUri, newUri ->
trainViewModel.replaceImage(oldUri, newUri)
},
onSelectFaceFromMultiple = { result ->
showFacePickerDialog = result
},
trainViewModel = trainViewModel
)
}
is ScanningState.Error -> {
ErrorView(message = state.message, onRetry = onFinish)
}
}
if (trainingState is TrainingState.Processing) {
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
}
if (trainingState is TrainingState.Processing) {
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
}
}

View File

@@ -5,11 +5,18 @@ import android.graphics.Bitmap
import android.net.Uri
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import androidx.work.WorkManager
import android.content.Context
import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.workers.LibraryScanWorker
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
@@ -48,15 +55,20 @@ data class PersonInfo(
// Preferences DataStore named "settings", exposed as an extension property on Context.
private val android.content.Context.dataStore by preferencesDataStore(name = "settings")

// Boolean preference key "background_recognition_tagging".
private val KEY_BACKGROUND_TAGGING = booleanPreferencesKey("background_recognition_tagging")

/**
 * FIXED TrainViewModel with proper exclude functionality and efficient replace
 */
@HiltViewModel
class TrainViewModel @Inject constructor(
application: Application,
private val faceRecognitionRepository: FaceRecognitionRepository,
private val faceNetModel: FaceNetModel
private val faceNetModel: FaceNetModel,
private val workManager: WorkManager
) : AndroidViewModel(application) {
private val sanityChecker = TrainingSanityChecker(application)
private val faceDetectionHelper = FaceDetectionHelper(application)
private val dataStore = application.dataStore
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
@@ -174,6 +186,20 @@ class TrainViewModel @Inject constructor(
relationship = person.relationship
)
// Trigger library scan if setting enabled
val backgroundTaggingEnabled = dataStore.data
.map { it[KEY_BACKGROUND_TAGGING] ?: true }
.first()
if (backgroundTaggingEnabled) {
val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId,
personName = personName,
threshold = 0.65f
)
workManager.enqueue(scanRequest)
}
} catch (e: Exception) {
_trainingState.value = TrainingState.Error(
e.message ?: "Failed to create face model"

View File

@@ -9,6 +9,8 @@ import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
@@ -146,7 +148,12 @@ class LibraryScanWorker @AssistedInject constructor(
.build()
)
val modelEmbedding = faceModel.getEmbeddingArray()
// Get ALL centroids for multi-centroid matching (critical for children)
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
if (modelCentroids.isEmpty()) {
return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
}
var matchesFound = 0
var photosScanned = 0
@@ -164,7 +171,7 @@ class LibraryScanWorker @AssistedInject constructor(
photo = photo,
personId = personId,
faceModelId = faceModel.id,
modelEmbedding = modelEmbedding,
modelCentroids = modelCentroids,
faceNetModel = faceNetModel,
detector = detector,
threshold = threshold
@@ -228,7 +235,7 @@ class LibraryScanWorker @AssistedInject constructor(
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
personId: String,
faceModelId: String,
modelEmbedding: FloatArray,
modelCentroids: List<FloatArray>,
faceNetModel: FaceNetModel,
detector: com.google.mlkit.vision.face.FaceDetector,
threshold: Float
@@ -243,24 +250,26 @@ class LibraryScanWorker @AssistedInject constructor(
val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await()
// Check each face
// Check each face (filter by quality first)
val tags = faces.mapNotNull { face ->
// Quality check
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
return@mapNotNull null
}
try {
// Crop face
val faceBitmap = android.graphics.Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
// Crop and normalize face for best recognition
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
?: return@mapNotNull null
// Generate embedding
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
// Calculate similarity
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
// Match against ALL centroids, use best match (critical for children)
val similarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
if (similarity >= threshold) {
PhotoFaceTagEntity.create(