jFc
This commit is contained in:
2
.idea/deploymentTargetSelector.xml
generated
2
.idea/deploymentTargetSelector.xml
generated
@@ -4,7 +4,7 @@
|
|||||||
<selectionStates>
|
<selectionStates>
|
||||||
<SelectionState runConfigName="app">
|
<SelectionState runConfigName="app">
|
||||||
<option name="selectionMode" value="DROPDOWN" />
|
<option name="selectionMode" value="DROPDOWN" />
|
||||||
<DropdownSelection timestamp="2026-01-25T20:45:06.118763497Z">
|
<DropdownSelection timestamp="2026-01-26T02:23:12.309011764Z">
|
||||||
<Target type="DEFAULT_BOOT">
|
<Target type="DEFAULT_BOOT">
|
||||||
<handle>
|
<handle>
|
||||||
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
|
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ dependencies {
|
|||||||
implementation(libs.androidx.lifecycle.viewmodel.compose)
|
implementation(libs.androidx.lifecycle.viewmodel.compose)
|
||||||
implementation(libs.androidx.activity.compose)
|
implementation(libs.androidx.activity.compose)
|
||||||
|
|
||||||
|
// DataStore Preferences
|
||||||
|
implementation("androidx.datastore:datastore-preferences:1.1.1")
|
||||||
|
|
||||||
// Compose
|
// Compose
|
||||||
implementation(platform(libs.androidx.compose.bom))
|
implementation(platform(libs.androidx.compose.bom))
|
||||||
implementation(libs.androidx.compose.ui)
|
implementation(libs.androidx.compose.ui)
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ interface ImageDao {
|
|||||||
@Query("SELECT * FROM images WHERE imageId = :imageId")
|
@Query("SELECT * FROM images WHERE imageId = :imageId")
|
||||||
suspend fun getImageById(imageId: String): ImageEntity?
|
suspend fun getImageById(imageId: String): ImageEntity?
|
||||||
|
|
||||||
|
@Query("SELECT * FROM images WHERE imageUri = :uri LIMIT 1")
|
||||||
|
suspend fun getImageByUri(uri: String): ImageEntity?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stream images ordered by capture time (newest first).
|
* Stream images ordered by capture time (newest first).
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package com.placeholder.sherpai2.data.repository
|
|||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.graphics.Bitmap
|
import android.graphics.Bitmap
|
||||||
|
import android.util.Log
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.*
|
import com.placeholder.sherpai2.data.local.entity.*
|
||||||
@@ -31,8 +33,12 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
private val personDao: PersonDao,
|
private val personDao: PersonDao,
|
||||||
private val imageDao: ImageDao,
|
private val imageDao: ImageDao,
|
||||||
private val faceModelDao: FaceModelDao,
|
private val faceModelDao: FaceModelDao,
|
||||||
private val photoFaceTagDao: PhotoFaceTagDao
|
private val photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
private val personAgeTagDao: PersonAgeTagDao
|
||||||
) {
|
) {
|
||||||
|
companion object {
|
||||||
|
private const val TAG = "FaceRecognitionRepo"
|
||||||
|
}
|
||||||
|
|
||||||
private val faceNetModel by lazy { FaceNetModel(context) }
|
private val faceNetModel by lazy { FaceNetModel(context) }
|
||||||
|
|
||||||
@@ -181,12 +187,15 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
var highestSimilarity = threshold
|
var highestSimilarity = threshold
|
||||||
|
|
||||||
for (faceModel in faceModels) {
|
for (faceModel in faceModels) {
|
||||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
// Check ALL centroids for best match (critical for children with age centroids)
|
||||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
val centroids = faceModel.getCentroids()
|
||||||
|
val bestCentroidSimilarity = centroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid.getEmbeddingArray())
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
if (similarity > highestSimilarity) {
|
if (bestCentroidSimilarity > highestSimilarity) {
|
||||||
highestSimilarity = similarity
|
highestSimilarity = bestCentroidSimilarity
|
||||||
bestMatch = Pair(faceModel.id, similarity)
|
bestMatch = Pair(faceModel.id, bestCentroidSimilarity)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,9 +383,49 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
onProgress = onProgress
|
onProgress = onProgress
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Generate age tags for children
|
||||||
|
if (person.isChild && person.dateOfBirth != null) {
|
||||||
|
generateAgeTagsForTraining(person, validImages)
|
||||||
|
}
|
||||||
|
|
||||||
person.id
|
person.id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate age tags from training images for a child
|
||||||
|
*/
|
||||||
|
private suspend fun generateAgeTagsForTraining(
|
||||||
|
person: PersonEntity,
|
||||||
|
validImages: List<TrainingSanityChecker.ValidTrainingImage>
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
val dob = person.dateOfBirth ?: return
|
||||||
|
|
||||||
|
val tags = validImages.mapNotNull { img ->
|
||||||
|
val imageEntity = imageDao.getImageByUri(img.uri.toString()) ?: return@mapNotNull null
|
||||||
|
val ageMs = imageEntity.capturedAt - dob
|
||||||
|
val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt()
|
||||||
|
|
||||||
|
if (ageYears < 0 || ageYears > 25) return@mapNotNull null
|
||||||
|
|
||||||
|
PersonAgeTagEntity.create(
|
||||||
|
personId = person.id,
|
||||||
|
personName = person.name,
|
||||||
|
imageId = imageEntity.imageId,
|
||||||
|
ageAtCapture = ageYears,
|
||||||
|
confidence = 1.0f
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tags.isNotEmpty()) {
|
||||||
|
personAgeTagDao.insertTags(tags)
|
||||||
|
Log.d(TAG, "Created ${tags.size} age tags for ${person.name}")
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Failed to generate age tags", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get face model by ID
|
* Get face model by ID
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -61,14 +61,16 @@ abstract class RepositoryModule {
|
|||||||
personDao: PersonDao,
|
personDao: PersonDao,
|
||||||
imageDao: ImageDao,
|
imageDao: ImageDao,
|
||||||
faceModelDao: FaceModelDao,
|
faceModelDao: FaceModelDao,
|
||||||
photoFaceTagDao: PhotoFaceTagDao
|
photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
personAgeTagDao: PersonAgeTagDao
|
||||||
): FaceRecognitionRepository {
|
): FaceRecognitionRepository {
|
||||||
return FaceRecognitionRepository(
|
return FaceRecognitionRepository(
|
||||||
context = context,
|
context = context,
|
||||||
personDao = personDao,
|
personDao = personDao,
|
||||||
imageDao = imageDao,
|
imageDao = imageDao,
|
||||||
faceModelDao = faceModelDao,
|
faceModelDao = faceModelDao,
|
||||||
photoFaceTagDao = photoFaceTagDao
|
photoFaceTagDao = photoFaceTagDao,
|
||||||
|
personAgeTagDao = personAgeTagDao
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import com.placeholder.sherpai2.data.local.dao.ImageDao
|
|||||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNormalizer
|
||||||
import com.placeholder.sherpai2.ui.discover.DiscoverySettings
|
import com.placeholder.sherpai2.ui.discover.DiscoverySettings
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
@@ -344,14 +345,9 @@ class FaceClusteringService @Inject constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Crop and generate embedding
|
// Crop and normalize face
|
||||||
val faceBitmap = Bitmap.createBitmap(
|
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, mlFace)
|
||||||
bitmap,
|
?: return@forEach
|
||||||
mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
|
||||||
mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
|
||||||
mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left),
|
|
||||||
mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top)
|
|
||||||
)
|
|
||||||
|
|
||||||
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
@@ -591,13 +587,8 @@ class FaceClusteringService @Inject constructor(
|
|||||||
if (!qualityCheck.isValid) return@mapNotNull null
|
if (!qualityCheck.isValid) return@mapNotNull null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val faceBitmap = Bitmap.createBitmap(
|
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
|
||||||
bitmap,
|
?: return@mapNotNull null
|
||||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
|
||||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
|
||||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
|
||||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
|
||||||
)
|
|
||||||
|
|
||||||
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
|
|||||||
@@ -339,10 +339,7 @@ fun AppNavHost(
|
|||||||
* SETTINGS SCREEN
|
* SETTINGS SCREEN
|
||||||
*/
|
*/
|
||||||
composable(AppRoutes.SETTINGS) {
|
composable(AppRoutes.SETTINGS) {
|
||||||
DummyScreen(
|
com.placeholder.sherpai2.ui.settings.SettingsScreen()
|
||||||
title = "Settings",
|
|
||||||
subtitle = "App preferences and configuration"
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -78,6 +78,7 @@ fun MainScreen(
|
|||||||
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
|
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
|
||||||
AppRoutes.INVENTORY -> "People"
|
AppRoutes.INVENTORY -> "People"
|
||||||
AppRoutes.TRAIN -> "Train Model"
|
AppRoutes.TRAIN -> "Train Model"
|
||||||
|
AppRoutes.ScanResultsScreen -> "Train New Person"
|
||||||
AppRoutes.TAGS -> "Tags"
|
AppRoutes.TAGS -> "Tags"
|
||||||
AppRoutes.UTILITIES -> "Utilities"
|
AppRoutes.UTILITIES -> "Utilities"
|
||||||
AppRoutes.SETTINGS -> "Settings"
|
AppRoutes.SETTINGS -> "Settings"
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ import android.graphics.BitmapFactory
|
|||||||
import android.graphics.Rect
|
import android.graphics.Rect
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import com.google.mlkit.vision.common.InputImage
|
import com.google.mlkit.vision.common.InputImage
|
||||||
|
import com.google.mlkit.vision.face.Face
|
||||||
import com.google.mlkit.vision.face.FaceDetection
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNormalizer
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
import kotlinx.coroutines.awaitAll
|
import kotlinx.coroutines.awaitAll
|
||||||
@@ -64,21 +67,29 @@ class FaceDetectionHelper(private val context: Context) {
|
|||||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
val faces = detector.process(inputImage).await()
|
val faces = detector.process(inputImage).await()
|
||||||
|
|
||||||
// Sort by face size (area) to get the largest face
|
// Filter to quality faces only
|
||||||
val sortedFaces = faces.sortedByDescending { face ->
|
val qualityFaces = faces.filter { face ->
|
||||||
|
FaceQualityFilter.validateForDiscovery(
|
||||||
|
face = face,
|
||||||
|
imageWidth = bitmap.width,
|
||||||
|
imageHeight = bitmap.height
|
||||||
|
).isValid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by face size (area) to get the largest quality face
|
||||||
|
val sortedFaces = qualityFaces.sortedByDescending { face ->
|
||||||
face.boundingBox.width() * face.boundingBox.height()
|
face.boundingBox.width() * face.boundingBox.height()
|
||||||
}
|
}
|
||||||
|
|
||||||
val croppedFace = if (sortedFaces.isNotEmpty()) {
|
val croppedFace = if (sortedFaces.isNotEmpty()) {
|
||||||
// Crop the LARGEST detected face (most likely the subject)
|
FaceNormalizer.cropAndNormalize(bitmap, sortedFaces[0])
|
||||||
cropFaceFromBitmap(bitmap, sortedFaces[0].boundingBox)
|
|
||||||
} else null
|
} else null
|
||||||
|
|
||||||
FaceDetectionResult(
|
FaceDetectionResult(
|
||||||
uri = uri,
|
uri = uri,
|
||||||
hasFace = faces.isNotEmpty(),
|
hasFace = qualityFaces.isNotEmpty(),
|
||||||
faceCount = faces.size,
|
faceCount = qualityFaces.size,
|
||||||
faceBounds = faces.map { it.boundingBox },
|
faceBounds = qualityFaces.map { it.boundingBox },
|
||||||
croppedFaceBitmap = croppedFace
|
croppedFaceBitmap = croppedFace
|
||||||
)
|
)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
|
|||||||
@@ -51,57 +51,41 @@ fun ScanResultsScreen(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Scaffold(
|
// No Scaffold - MainScreen provides TopAppBar
|
||||||
topBar = {
|
Box(modifier = Modifier.fillMaxSize()) {
|
||||||
TopAppBar(
|
when (state) {
|
||||||
title = { Text("Train New Person") },
|
is ScanningState.Idle -> {}
|
||||||
colors = TopAppBarDefaults.topAppBarColors(
|
|
||||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
is ScanningState.Processing -> {
|
||||||
|
ProcessingView(progress = state.progress, total = state.total)
|
||||||
|
}
|
||||||
|
|
||||||
|
is ScanningState.Success -> {
|
||||||
|
ImprovedResultsView(
|
||||||
|
result = state.sanityCheckResult,
|
||||||
|
onContinue = {
|
||||||
|
trainViewModel.createFaceModel(
|
||||||
|
trainViewModel.getPersonInfo()?.name ?: "Unknown"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
onRetry = onFinish,
|
||||||
|
onReplaceImage = { oldUri, newUri ->
|
||||||
|
trainViewModel.replaceImage(oldUri, newUri)
|
||||||
|
},
|
||||||
|
onSelectFaceFromMultiple = { result ->
|
||||||
|
showFacePickerDialog = result
|
||||||
|
},
|
||||||
|
trainViewModel = trainViewModel
|
||||||
)
|
)
|
||||||
)
|
}
|
||||||
|
|
||||||
|
is ScanningState.Error -> {
|
||||||
|
ErrorView(message = state.message, onRetry = onFinish)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
) { paddingValues ->
|
|
||||||
Box(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxSize()
|
|
||||||
.padding(paddingValues)
|
|
||||||
) {
|
|
||||||
when (state) {
|
|
||||||
is ScanningState.Idle -> {}
|
|
||||||
|
|
||||||
is ScanningState.Processing -> {
|
if (trainingState is TrainingState.Processing) {
|
||||||
ProcessingView(progress = state.progress, total = state.total)
|
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
|
||||||
}
|
|
||||||
|
|
||||||
is ScanningState.Success -> {
|
|
||||||
ImprovedResultsView(
|
|
||||||
result = state.sanityCheckResult,
|
|
||||||
onContinue = {
|
|
||||||
// PersonInfo already captured in TrainingScreen!
|
|
||||||
// Just start training with stored info
|
|
||||||
trainViewModel.createFaceModel(
|
|
||||||
trainViewModel.getPersonInfo()?.name ?: "Unknown"
|
|
||||||
)
|
|
||||||
},
|
|
||||||
onRetry = onFinish,
|
|
||||||
onReplaceImage = { oldUri, newUri ->
|
|
||||||
trainViewModel.replaceImage(oldUri, newUri)
|
|
||||||
},
|
|
||||||
onSelectFaceFromMultiple = { result ->
|
|
||||||
showFacePickerDialog = result
|
|
||||||
},
|
|
||||||
trainViewModel = trainViewModel
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
is ScanningState.Error -> {
|
|
||||||
ErrorView(message = state.message, onRetry = onFinish)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (trainingState is TrainingState.Processing) {
|
|
||||||
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,18 @@ import android.graphics.Bitmap
|
|||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import androidx.lifecycle.AndroidViewModel
|
import androidx.lifecycle.AndroidViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import androidx.datastore.preferences.core.booleanPreferencesKey
|
||||||
|
import androidx.datastore.preferences.preferencesDataStore
|
||||||
|
import androidx.work.WorkManager
|
||||||
|
import android.content.Context
|
||||||
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
||||||
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
|
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
|
import com.placeholder.sherpai2.workers.LibraryScanWorker
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.first
|
||||||
|
import kotlinx.coroutines.flow.map
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
@@ -48,15 +55,20 @@ data class PersonInfo(
|
|||||||
/**
|
/**
|
||||||
* FIXED TrainViewModel with proper exclude functionality and efficient replace
|
* FIXED TrainViewModel with proper exclude functionality and efficient replace
|
||||||
*/
|
*/
|
||||||
|
private val android.content.Context.dataStore by preferencesDataStore(name = "settings")
|
||||||
|
private val KEY_BACKGROUND_TAGGING = booleanPreferencesKey("background_recognition_tagging")
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class TrainViewModel @Inject constructor(
|
class TrainViewModel @Inject constructor(
|
||||||
application: Application,
|
application: Application,
|
||||||
private val faceRecognitionRepository: FaceRecognitionRepository,
|
private val faceRecognitionRepository: FaceRecognitionRepository,
|
||||||
private val faceNetModel: FaceNetModel
|
private val faceNetModel: FaceNetModel,
|
||||||
|
private val workManager: WorkManager
|
||||||
) : AndroidViewModel(application) {
|
) : AndroidViewModel(application) {
|
||||||
|
|
||||||
private val sanityChecker = TrainingSanityChecker(application)
|
private val sanityChecker = TrainingSanityChecker(application)
|
||||||
private val faceDetectionHelper = FaceDetectionHelper(application)
|
private val faceDetectionHelper = FaceDetectionHelper(application)
|
||||||
|
private val dataStore = application.dataStore
|
||||||
|
|
||||||
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
|
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
|
||||||
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
|
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
|
||||||
@@ -174,6 +186,20 @@ class TrainViewModel @Inject constructor(
|
|||||||
relationship = person.relationship
|
relationship = person.relationship
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Trigger library scan if setting enabled
|
||||||
|
val backgroundTaggingEnabled = dataStore.data
|
||||||
|
.map { it[KEY_BACKGROUND_TAGGING] ?: true }
|
||||||
|
.first()
|
||||||
|
|
||||||
|
if (backgroundTaggingEnabled) {
|
||||||
|
val scanRequest = LibraryScanWorker.createWorkRequest(
|
||||||
|
personId = personId,
|
||||||
|
personName = personName,
|
||||||
|
threshold = 0.65f
|
||||||
|
)
|
||||||
|
workManager.enqueue(scanRequest)
|
||||||
|
}
|
||||||
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
_trainingState.value = TrainingState.Error(
|
_trainingState.value = TrainingState.Error(
|
||||||
e.message ?: "Failed to create face model"
|
e.message ?: "Failed to create face model"
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import com.google.mlkit.vision.common.InputImage
|
|||||||
import com.google.mlkit.vision.face.FaceDetection
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNormalizer
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
||||||
@@ -146,7 +148,12 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
// Get ALL centroids for multi-centroid matching (critical for children)
|
||||||
|
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
|
||||||
|
if (modelCentroids.isEmpty()) {
|
||||||
|
return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
|
||||||
|
}
|
||||||
|
|
||||||
var matchesFound = 0
|
var matchesFound = 0
|
||||||
var photosScanned = 0
|
var photosScanned = 0
|
||||||
|
|
||||||
@@ -164,7 +171,7 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
photo = photo,
|
photo = photo,
|
||||||
personId = personId,
|
personId = personId,
|
||||||
faceModelId = faceModel.id,
|
faceModelId = faceModel.id,
|
||||||
modelEmbedding = modelEmbedding,
|
modelCentroids = modelCentroids,
|
||||||
faceNetModel = faceNetModel,
|
faceNetModel = faceNetModel,
|
||||||
detector = detector,
|
detector = detector,
|
||||||
threshold = threshold
|
threshold = threshold
|
||||||
@@ -228,7 +235,7 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
|
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
|
||||||
personId: String,
|
personId: String,
|
||||||
faceModelId: String,
|
faceModelId: String,
|
||||||
modelEmbedding: FloatArray,
|
modelCentroids: List<FloatArray>,
|
||||||
faceNetModel: FaceNetModel,
|
faceNetModel: FaceNetModel,
|
||||||
detector: com.google.mlkit.vision.face.FaceDetector,
|
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||||
threshold: Float
|
threshold: Float
|
||||||
@@ -243,24 +250,26 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
val faces = detector.process(inputImage).await()
|
val faces = detector.process(inputImage).await()
|
||||||
|
|
||||||
// Check each face
|
// Check each face (filter by quality first)
|
||||||
val tags = faces.mapNotNull { face ->
|
val tags = faces.mapNotNull { face ->
|
||||||
|
// Quality check
|
||||||
|
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
|
||||||
|
return@mapNotNull null
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Crop face
|
// Crop and normalize face for best recognition
|
||||||
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
|
||||||
bitmap,
|
?: return@mapNotNull null
|
||||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
|
||||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
|
||||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
|
||||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
|
||||||
)
|
|
||||||
|
|
||||||
// Generate embedding
|
// Generate embedding
|
||||||
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
|
|
||||||
// Calculate similarity
|
// Match against ALL centroids, use best match (critical for children)
|
||||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
val similarity = modelCentroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
if (similarity >= threshold) {
|
if (similarity >= threshold) {
|
||||||
PhotoFaceTagEntity.create(
|
PhotoFaceTagEntity.create(
|
||||||
|
|||||||
Reference in New Issue
Block a user