Not quite happy

Improving scanning logic / flow
This commit is contained in:
genki
2026-01-14 07:58:21 -05:00
parent 393e5ecede
commit bf0bdfbd2e
5 changed files with 839 additions and 234 deletions

View File

@@ -376,6 +376,14 @@ class FaceRecognitionRepository @Inject constructor(
photoFaceTagDao.deleteTagsForImage(imageId)
}
/**
 * Returns the IDs of every image already tagged with the given face model.
 *
 * The scan pipeline uses this to skip images that were tagged by a previous
 * pass (scan optimization). DAO access is confined to the IO dispatcher.
 */
suspend fun getImageIdsForFaceModel(faceModelId: String): List<String> {
    return withContext(Dispatchers.IO) {
        photoFaceTagDao.getImageIdsForFaceModel(faceModelId)
    }
}
/** Releases resources held by the FaceNet model; call when this repository is no longer needed. */
fun cleanup() {
    faceNetModel.close()
}
@@ -397,4 +405,3 @@ data class PersonFaceStats(
val averageConfidence: Float,
val lastDetectedAt: Long?
)

View File

@@ -1,5 +1,9 @@
package com.placeholder.sherpai2.ui.modelinventory
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyColumn
@@ -14,9 +18,12 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import com.placeholder.sherpai2.ui.trainingprep.TrainingSanityChecker
/**
* CLEANED PersonInventoryScreen - No duplicate header
@@ -39,6 +46,7 @@ fun PersonInventoryScreen(
) {
val uiState by viewModel.uiState.collectAsState()
val scanningState by viewModel.scanningState.collectAsState()
val improvementState by viewModel.improvementState.collectAsState()
var personToDelete by remember { mutableStateOf<PersonInventoryViewModel.PersonWithStats?>(null) }
var personToScan by remember { mutableStateOf<PersonInventoryViewModel.PersonWithStats?>(null) }
@@ -90,7 +98,10 @@ fun PersonInventoryScreen(
person = person,
onDelete = { personToDelete = person },
onScan = { personToScan = person },
onViewPhotos = { onViewPersonPhotos(person.person.id) }
onViewPhotos = { onViewPersonPhotos(person.person.id) },
onImproveModel = {
viewModel.startModelImprovement(person.person.id, person.stats.faceModelId)
}
)
}
}
@@ -127,6 +138,12 @@ fun PersonInventoryScreen(
}
)
}
// Model improvement dialogs
HandleModelImprovementState(
improvementState = improvementState,
viewModel = viewModel
)
}
/**
@@ -197,7 +214,8 @@ private fun PersonCard(
person: PersonInventoryViewModel.PersonWithStats,
onDelete: () -> Unit,
onScan: () -> Unit,
onViewPhotos: () -> Unit
onViewPhotos: () -> Unit,
onImproveModel: () -> Unit
) {
Card(
modifier = Modifier.fillMaxWidth(),
@@ -263,6 +281,22 @@ private fun PersonCard(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(8.dp)
) {
OutlinedButton(
onClick = onImproveModel,
modifier = Modifier.weight(1f),
colors = ButtonDefaults.outlinedButtonColors(
contentColor = MaterialTheme.colorScheme.tertiary
)
) {
Icon(
Icons.Default.TrendingUp,
contentDescription = null,
modifier = Modifier.size(18.dp)
)
Spacer(Modifier.width(4.dp))
Text("Improve")
}
OutlinedButton(
onClick = onScan,
modifier = Modifier.weight(1f)
@@ -506,3 +540,378 @@ private fun ScanDialog(
}
)
}
/**
 * Renders the dialog that corresponds to the current step of the
 * model-improvement flow ([PersonInventoryViewModel.ModelImprovementState]).
 *
 * Each sealed state maps to exactly one dialog; Idle renders nothing.
 * All user actions are forwarded back to [viewModel].
 */
@Composable
private fun HandleModelImprovementState(
    improvementState: PersonInventoryViewModel.ModelImprovementState,
    viewModel: PersonInventoryViewModel
) {
    when (improvementState) {
        is PersonInventoryViewModel.ModelImprovementState.SelectingPhotos -> {
            // System photo picker; an empty selection is treated as cancel.
            val launcher = rememberLauncherForActivityResult(
                contract = ActivityResultContracts.GetMultipleContents()
            ) { uris ->
                if (uris.isNotEmpty()) {
                    viewModel.processSelectedPhotos(
                        personId = improvementState.personId,
                        faceModelId = improvementState.faceModelId,
                        selectedImageUris = uris
                    )
                } else {
                    viewModel.cancelModelImprovement()
                }
            }
            // Launch the picker once when this state is first composed.
            LaunchedEffect(Unit) {
                launcher.launch("image/*")
            }
            // Shown behind/alongside the picker so the user has context + a cancel path.
            AlertDialog(
                onDismissRequest = { viewModel.cancelModelImprovement() },
                icon = { Icon(Icons.Default.TrendingUp, contentDescription = null) },
                title = { Text("Improve ${improvementState.personName}'s Model") },
                text = {
                    Column(verticalArrangement = Arrangement.spacedBy(12.dp)) {
                        Text("Add 5-15 photos to improve accuracy")
                        Card(
                            colors = CardDefaults.cardColors(
                                containerColor = MaterialTheme.colorScheme.tertiaryContainer.copy(alpha = 0.3f)
                            )
                        ) {
                            Column(
                                modifier = Modifier.padding(12.dp),
                                verticalArrangement = Arrangement.spacedBy(8.dp)
                            ) {
                                Text(
                                    "Current: ${improvementState.currentTrainingCount} photos",
                                    style = MaterialTheme.typography.labelMedium,
                                    fontWeight = FontWeight.Bold
                                )
                            }
                        }
                    }
                },
                confirmButton = {},
                dismissButton = {
                    TextButton(onClick = { viewModel.cancelModelImprovement() }) {
                        Text("Cancel")
                    }
                }
            )
        }
        is PersonInventoryViewModel.ModelImprovementState.ValidatingPhotos -> {
            // Non-dismissible progress dialog while sanity checks run.
            AlertDialog(
                onDismissRequest = {},
                title = { Text("Validating Photos") },
                text = {
                    Column(verticalArrangement = Arrangement.spacedBy(16.dp)) {
                        LinearProgressIndicator(
                            progress = {
                                if (improvementState.total > 0) {
                                    improvementState.current.toFloat() / improvementState.total
                                } else 0f
                            },
                            modifier = Modifier.fillMaxWidth()
                        )
                        Text(improvementState.progress)
                        Text(
                            "${improvementState.current} / ${improvementState.total}",
                            style = MaterialTheme.typography.bodySmall
                        )
                    }
                },
                confirmButton = {}
            )
        }
        is PersonInventoryViewModel.ModelImprovementState.ReviewingPhotos -> {
            // Full review UI lives in its own composable.
            ReviewPhotosDialog(
                state = improvementState,
                onConfirm = {
                    viewModel.retrainModelWithValidatedPhotos(
                        personId = improvementState.personId,
                        faceModelId = improvementState.faceModelId,
                        sanityCheckResult = improvementState.sanityCheckResult
                    )
                },
                onDismiss = { viewModel.cancelModelImprovement() }
            )
        }
        is PersonInventoryViewModel.ModelImprovementState.Training -> {
            // Non-dismissible progress dialog while the model retrains.
            AlertDialog(
                onDismissRequest = {},
                title = { Text("Training Model") },
                text = {
                    Column(verticalArrangement = Arrangement.spacedBy(16.dp)) {
                        LinearProgressIndicator(
                            progress = {
                                if (improvementState.total > 0) {
                                    improvementState.progress.toFloat() / improvementState.total
                                } else 0f
                            },
                            modifier = Modifier.fillMaxWidth()
                        )
                        Text(improvementState.currentPhase)
                    }
                },
                confirmButton = {}
            )
        }
        is PersonInventoryViewModel.ModelImprovementState.TrainingComplete -> {
            // Success summary with before/after stats.
            AlertDialog(
                onDismissRequest = { viewModel.cancelModelImprovement() },
                icon = {
                    Icon(
                        Icons.Default.CheckCircle,
                        contentDescription = null,
                        tint = MaterialTheme.colorScheme.primary
                    )
                },
                title = { Text("Model Improved!") },
                text = {
                    Column(verticalArrangement = Arrangement.spacedBy(12.dp)) {
                        Text("Successfully improved ${improvementState.personName}'s model")
                        Card(
                            colors = CardDefaults.cardColors(
                                containerColor = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f)
                            )
                        ) {
                            Column(
                                modifier = Modifier.padding(12.dp),
                                verticalArrangement = Arrangement.spacedBy(8.dp)
                            ) {
                                Row(
                                    modifier = Modifier.fillMaxWidth(),
                                    horizontalArrangement = Arrangement.SpaceBetween
                                ) {
                                    Text("Photos added:", style = MaterialTheme.typography.bodySmall)
                                    Text("${improvementState.photosAdded}", fontWeight = FontWeight.Bold)
                                }
                                Row(
                                    modifier = Modifier.fillMaxWidth(),
                                    horizontalArrangement = Arrangement.SpaceBetween
                                ) {
                                    Text("New count:", style = MaterialTheme.typography.bodySmall)
                                    Text("${improvementState.newTrainingCount}", fontWeight = FontWeight.Bold)
                                }
                                HorizontalDivider()
                                // Old confidence -> new confidence, as percentages.
                                Row(
                                    modifier = Modifier.fillMaxWidth(),
                                    horizontalArrangement = Arrangement.SpaceBetween
                                ) {
                                    Text("${String.format("%.1f", improvementState.oldConfidence * 100)}%")
                                    Icon(Icons.Default.ArrowForward, contentDescription = null, modifier = Modifier.size(16.dp))
                                    Text(
                                        "${String.format("%.1f", improvementState.newConfidence * 100)}%",
                                        fontWeight = FontWeight.Bold,
                                        color = MaterialTheme.colorScheme.primary
                                    )
                                }
                            }
                        }
                    }
                },
                confirmButton = {
                    Button(onClick = { viewModel.cancelModelImprovement() }) {
                        Text("Done")
                    }
                }
            )
        }
        is PersonInventoryViewModel.ModelImprovementState.Error -> {
            AlertDialog(
                onDismissRequest = { viewModel.cancelModelImprovement() },
                icon = { Icon(Icons.Default.Error, contentDescription = null) },
                title = { Text("Error") },
                text = { Text(improvementState.message) },
                confirmButton = {
                    TextButton(onClick = { viewModel.cancelModelImprovement() }) {
                        Text("OK")
                    }
                }
            )
        }
        PersonInventoryViewModel.ModelImprovementState.Idle -> {}
    }
}
/**
 * Review step of the improvement flow: shows a summary card, the validated
 * face crops, and any validation issues from the sanity check.
 *
 * The confirm button is disabled when there are no valid photos; [onConfirm]
 * starts retraining, [onDismiss] cancels the whole flow.
 */
@Composable
private fun ReviewPhotosDialog(
    state: PersonInventoryViewModel.ModelImprovementState.ReviewingPhotos,
    onConfirm: () -> Unit,
    onDismiss: () -> Unit
) {
    val validImages = state.sanityCheckResult.validImagesWithFaces
    val hasErrors = state.sanityCheckResult.validationErrors.isNotEmpty()
    AlertDialog(
        onDismissRequest = onDismiss,
        title = { Text("Review Photos") },
        text = {
            LazyColumn(
                modifier = Modifier.height(400.dp),
                verticalArrangement = Arrangement.spacedBy(12.dp)
            ) {
                // Summary card: valid count, issue count, projected new training total.
                item {
                    Card(
                        colors = CardDefaults.cardColors(
                            containerColor = if (!hasErrors) {
                                MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f)
                            } else {
                                MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
                            }
                        )
                    ) {
                        Row(
                            modifier = Modifier
                                .fillMaxWidth()
                                .padding(12.dp),
                            horizontalArrangement = Arrangement.SpaceBetween,
                            verticalAlignment = Alignment.CenterVertically
                        ) {
                            Column {
                                Text(
                                    "${validImages.size} valid photos",
                                    style = MaterialTheme.typography.titleMedium,
                                    fontWeight = FontWeight.Bold,
                                    color = if (!hasErrors) {
                                        MaterialTheme.colorScheme.primary
                                    } else {
                                        MaterialTheme.colorScheme.error
                                    }
                                )
                                if (hasErrors) {
                                    Text(
                                        "${state.sanityCheckResult.validationErrors.size} issues",
                                        style = MaterialTheme.typography.bodySmall,
                                        color = MaterialTheme.colorScheme.error
                                    )
                                }
                            }
                            // Projected training-image count if the user confirms.
                            Text(
                                "${state.currentTrainingCount + validImages.size}",
                                style = MaterialTheme.typography.headlineMedium,
                                fontWeight = FontWeight.Bold
                            )
                        }
                    }
                }
                // One row per validated photo, with its cropped face thumbnail.
                if (validImages.isNotEmpty()) {
                    item {
                        Text(
                            "Valid Photos",
                            style = MaterialTheme.typography.labelLarge,
                            fontWeight = FontWeight.Bold
                        )
                    }
                    items(validImages) { img ->
                        Card(modifier = Modifier.fillMaxWidth()) {
                            Row(
                                modifier = Modifier.padding(8.dp),
                                horizontalArrangement = Arrangement.spacedBy(12.dp),
                                verticalAlignment = Alignment.CenterVertically
                            ) {
                                Image(
                                    bitmap = img.croppedFaceBitmap.asImageBitmap(),
                                    contentDescription = null,
                                    modifier = Modifier
                                        .size(64.dp)
                                        .clip(RoundedCornerShape(8.dp)),
                                    contentScale = ContentScale.Crop
                                )
                                Column {
                                    Row(horizontalArrangement = Arrangement.spacedBy(4.dp)) {
                                        Icon(
                                            Icons.Default.CheckCircle,
                                            contentDescription = null,
                                            modifier = Modifier.size(16.dp),
                                            tint = MaterialTheme.colorScheme.primary
                                        )
                                        Text(
                                            "Valid",
                                            style = MaterialTheme.typography.labelMedium,
                                            fontWeight = FontWeight.Bold
                                        )
                                    }
                                    Text(
                                        "${img.faceCount} face(s)",
                                        style = MaterialTheme.typography.bodySmall
                                    )
                                }
                            }
                        }
                    }
                }
                // One row per validation error, mapped to a short human-readable label.
                if (hasErrors) {
                    item {
                        Text(
                            "Issues",
                            style = MaterialTheme.typography.labelLarge,
                            fontWeight = FontWeight.Bold,
                            color = MaterialTheme.colorScheme.error
                        )
                    }
                    items(state.sanityCheckResult.validationErrors) { error ->
                        Card(
                            colors = CardDefaults.cardColors(
                                containerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
                            )
                        ) {
                            Row(
                                modifier = Modifier.padding(12.dp),
                                horizontalArrangement = Arrangement.spacedBy(12.dp)
                            ) {
                                Icon(
                                    Icons.Default.Warning,
                                    contentDescription = null,
                                    tint = MaterialTheme.colorScheme.error
                                )
                                Text(
                                    when (error) {
                                        is TrainingSanityChecker.ValidationError.NoFaceDetected ->
                                            "${error.uris.size} without faces"
                                        is TrainingSanityChecker.ValidationError.MultipleFacesDetected ->
                                            "Multiple faces"
                                        is TrainingSanityChecker.ValidationError.DuplicateImages ->
                                            "Duplicates"
                                        is TrainingSanityChecker.ValidationError.InsufficientImages ->
                                            "Need ${error.required}"
                                        is TrainingSanityChecker.ValidationError.ImageLoadError ->
                                            "Load failed"
                                    },
                                    fontWeight = FontWeight.Bold,
                                    color = MaterialTheme.colorScheme.error
                                )
                            }
                        }
                    }
                }
            }
        },
        confirmButton = {
            Button(
                onClick = onConfirm,
                enabled = validImages.isNotEmpty()
            ) {
                Text("Train (${validImages.size})")
            }
        },
        dismissButton = {
            TextButton(onClick = onDismiss) {
                Text("Cancel")
            }
        }
    )
}

View File

@@ -10,6 +10,7 @@ import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
import com.placeholder.sherpai2.data.repository.DetectedFace
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
import com.placeholder.sherpai2.data.repository.PersonFaceStats
@@ -17,26 +18,30 @@ import com.placeholder.sherpai2.domain.repository.ImageRepository
import com.placeholder.sherpai2.ml.ThresholdStrategy
import com.placeholder.sherpai2.ml.ImageQuality
import com.placeholder.sherpai2.ml.DetectionContext
import com.placeholder.sherpai2.ui.trainingprep.TrainingSanityChecker
import com.placeholder.sherpai2.ui.trainingprep.FaceDetectionHelper
import com.placeholder.sherpai2.util.DebugFlags
import com.placeholder.sherpai2.util.DiagnosticLogger
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.withContext
import kotlinx.coroutines.tasks.await
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import javax.inject.Inject
/**
* PersonInventoryViewModel - Enhanced with smart threshold strategy
*
* Toggle diagnostics in DebugFlags.kt:
* - ENABLE_FACE_RECOGNITION_LOGGING = true/false
* - USE_LIBERAL_THRESHOLDS = true/false
* PersonInventoryViewModel with optimized scanning and model improvement
*/
@HiltViewModel
class PersonInventoryViewModel @Inject constructor(
@@ -51,6 +56,13 @@ class PersonInventoryViewModel @Inject constructor(
private val _scanningState = MutableStateFlow<ScanningState>(ScanningState.Idle)
val scanningState: StateFlow<ScanningState> = _scanningState.asStateFlow()
private val _improvementState = MutableStateFlow<ModelImprovementState>(ModelImprovementState.Idle)
val improvementState: StateFlow<ModelImprovementState> = _improvementState.asStateFlow()
private val faceDetectionHelper = FaceDetectionHelper(application)
private val sanityChecker = TrainingSanityChecker(application)
private val faceDetectionCache = ConcurrentHashMap<String, List<DetectedFace>>()
private val faceDetector by lazy {
val options = FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
@@ -61,6 +73,12 @@ class PersonInventoryViewModel @Inject constructor(
FaceDetection.getClient(options)
}
companion object {
    // Max images processed concurrently during a scan (semaphore permit count).
    private const val PARALLEL_IMAGE_PROCESSING = 4
    // Number of images per chunk handed to the parallel scanner.
    private const val BATCH_SIZE = 20
    // Minimum interval between Scanning-state emissions, to avoid flooding the UI.
    private const val PROGRESS_UPDATE_INTERVAL_MS = 100L
}
data class PersonWithStats(
val person: PersonEntity,
val stats: PersonFaceStats
@@ -80,16 +98,57 @@ class PersonInventoryViewModel @Inject constructor(
val progress: Int,
val total: Int,
val facesFound: Int,
val facesDetected: Int = 0
val facesDetected: Int = 0,
val imagesSkipped: Int = 0
) : ScanningState()
data class Complete(
val personName: String,
val facesFound: Int,
val imagesScanned: Int,
val totalFacesDetected: Int = 0
val totalFacesDetected: Int = 0,
val imagesSkipped: Int = 0
) : ScanningState()
}
/**
 * UI state machine for the "improve model" flow:
 * Idle -> SelectingPhotos -> ValidatingPhotos -> ReviewingPhotos -> Training
 * -> TrainingComplete, with Error reachable from any step.
 */
sealed class ModelImprovementState {
    /** No improvement flow in progress. */
    object Idle : ModelImprovementState()
    /** The system photo picker is open for this person's model. */
    data class SelectingPhotos(
        val personId: String,
        val personName: String,
        val faceModelId: String,
        val currentTrainingCount: Int
    ) : ModelImprovementState()
    /** Sanity checks are running over the picked photos. */
    data class ValidatingPhotos(
        val personId: String,
        val personName: String,
        val faceModelId: String,
        val progress: String, // human-readable phase label
        val current: Int,
        val total: Int
    ) : ModelImprovementState()
    /** Validation finished; the user reviews valid photos and issues before training. */
    data class ReviewingPhotos(
        val personId: String,
        val personName: String,
        val faceModelId: String,
        val sanityCheckResult: TrainingSanityChecker.SanityCheckResult,
        val currentTrainingCount: Int
    ) : ModelImprovementState()
    /** Retraining is in progress. */
    data class Training(
        val personName: String,
        val progress: Int,
        val total: Int,
        val currentPhase: String
    ) : ModelImprovementState()
    /** Retraining succeeded; before/after stats for the success dialog. */
    data class TrainingComplete(
        val personName: String,
        val photosAdded: Int,
        val newTrainingCount: Int,
        val oldConfidence: Float,
        val newConfidence: Float
    ) : ModelImprovementState()
    /** A step failed; [message] is shown in the error dialog. */
    data class Error(val message: String) : ModelImprovementState()
}
init {
    // Populate the person inventory as soon as the ViewModel is created.
    loadPersons()
}
@@ -98,24 +157,14 @@ class PersonInventoryViewModel @Inject constructor(
viewModelScope.launch {
try {
_uiState.value = InventoryUiState.Loading
val persons = faceRecognitionRepository.getPersonsWithFaceModels()
val personsWithStats = persons.mapNotNull { person ->
val stats = faceRecognitionRepository.getPersonFaceStats(person.id)
if (stats != null) {
PersonWithStats(person, stats)
} else {
null
}
if (stats != null) PersonWithStats(person, stats) else null
}.sortedByDescending { it.stats.taggedPhotoCount }
_uiState.value = InventoryUiState.Success(personsWithStats)
} catch (e: Exception) {
_uiState.value = InventoryUiState.Error(
e.message ?: "Failed to load persons"
)
_uiState.value = InventoryUiState.Error(e.message ?: "Failed to load persons")
}
}
}
@@ -124,138 +173,91 @@ class PersonInventoryViewModel @Inject constructor(
viewModelScope.launch {
try {
faceRecognitionRepository.deleteFaceModel(faceModelId)
faceDetectionCache.clear()
loadPersons()
} catch (e: Exception) {
_uiState.value = InventoryUiState.Error(
"Failed to delete: ${e.message}"
)
_uiState.value = InventoryUiState.Error("Failed to delete: ${e.message}")
}
}
}
/**
* Scan library with SMART threshold selection
*/
fun scanLibraryForPerson(personId: String, faceModelId: String) {
viewModelScope.launch {
val startTime = System.currentTimeMillis()
try {
if (DebugFlags.ENABLE_FACE_RECOGNITION_LOGGING) {
DiagnosticLogger.i("=== STARTING LIBRARY SCAN (ENHANCED) ===")
DiagnosticLogger.i("PersonId: $personId")
DiagnosticLogger.i("FaceModelId: $faceModelId")
DiagnosticLogger.i("=== OPTIMIZED SCAN START ===")
}
val currentState = _uiState.value
val person = if (currentState is InventoryUiState.Success) {
currentState.persons.find { it.person.id == personId }?.person
} else null
val personName = person?.name ?: "Unknown"
// Get face model to determine training count
val faceModel = faceRecognitionRepository.getFaceModelById(faceModelId)
val trainingCount = faceModel?.trainingImageCount ?: 15
?: throw IllegalStateException("Face model not found")
val trainingCount = faceModel.trainingImageCount
DiagnosticLogger.i("Training count: $trainingCount")
val alreadyTaggedImageIds = faceRecognitionRepository
.getImageIdsForFaceModel(faceModelId).toSet()
val allImages = imageRepository.getAllImages().first()
val totalImages = allImages.size
DiagnosticLogger.i("Total images in library: $totalImages")
val processedCount = AtomicInteger(0)
val facesFoundCount = AtomicInteger(0)
val totalFacesDetectedCount = AtomicInteger(0)
val skippedCount = AtomicInteger(0)
_scanningState.value = ScanningState.Scanning(
personId = personId,
personName = personName,
progress = 0,
total = totalImages,
facesFound = 0,
facesDetected = 0
personId, personName, 0, totalImages, 0, 0, 0
)
var facesFound = 0
var totalFacesDetected = 0
val semaphore = Semaphore(PARALLEL_IMAGE_PROCESSING)
var lastProgressUpdate = 0L
allImages.forEachIndexed { index, imageWithEverything ->
val image = imageWithEverything.image
DiagnosticLogger.d("--- Image ${index + 1}/$totalImages ---")
DiagnosticLogger.d("ImageId: ${image.imageId}")
// Detect faces with ML Kit
val detectedFaces = detectFacesInImage(image.imageUri)
totalFacesDetected += detectedFaces.size
DiagnosticLogger.d("Faces detected: ${detectedFaces.size}")
if (detectedFaces.isNotEmpty()) {
// ENHANCED: Calculate image quality
val imageQuality = ThresholdStrategy.estimateImageQuality(
width = image.width,
height = image.height
)
// ENHANCED: Estimate detection context
val detectionContext = ThresholdStrategy.estimateDetectionContext(
faceCount = detectedFaces.size,
faceAreaRatio = if (detectedFaces.isNotEmpty()) {
calculateFaceAreaRatio(detectedFaces[0], image.width, image.height)
} else 0f
)
// ENHANCED: Get smart threshold
val scanThreshold = if (DebugFlags.USE_LIBERAL_THRESHOLDS) {
ThresholdStrategy.getLiberalThreshold(trainingCount)
} else {
ThresholdStrategy.getOptimalThreshold(
trainingCount = trainingCount,
imageQuality = imageQuality,
detectionContext = detectionContext
allImages.chunked(BATCH_SIZE).forEach { imageBatch ->
val batchResults = imageBatch.map { imageWithEverything ->
async(Dispatchers.Default) {
semaphore.withPermit {
processImageOptimized(
imageWithEverything,
faceModelId,
trainingCount,
alreadyTaggedImageIds
)
}
}
}.awaitAll()
DiagnosticLogger.d("Quality: $imageQuality, Context: $detectionContext")
DiagnosticLogger.d("Using threshold: $scanThreshold")
// Scan image with smart threshold
val tags = faceRecognitionRepository.scanImage(
imageId = image.imageId,
detectedFaces = detectedFaces,
threshold = scanThreshold
)
DiagnosticLogger.d("Tags created: ${tags.size}")
tags.forEach { tag ->
DiagnosticLogger.d(" Tag: model=${tag.faceModelId.take(8)}, conf=${String.format("%.3f", tag.confidence)}")
}
val matchingTags = tags.filter { it.faceModelId == faceModelId }
DiagnosticLogger.d("Matching tags for target: ${matchingTags.size}")
facesFound += matchingTags.size
batchResults.forEach { result ->
if (result != null) {
processedCount.incrementAndGet()
facesFoundCount.addAndGet(result.matchingTagsCount)
totalFacesDetectedCount.addAndGet(result.totalFacesDetected)
if (result.skipped) skippedCount.incrementAndGet()
}
}
val now = System.currentTimeMillis()
if (now - lastProgressUpdate > PROGRESS_UPDATE_INTERVAL_MS) {
_scanningState.value = ScanningState.Scanning(
personId = personId,
personName = personName,
progress = index + 1,
total = totalImages,
facesFound = facesFound,
facesDetected = totalFacesDetected
personId, personName,
processedCount.get(), totalImages,
facesFoundCount.get(), totalFacesDetectedCount.get(),
skippedCount.get()
)
lastProgressUpdate = now
}
}
DiagnosticLogger.i("=== SCAN COMPLETE ===")
DiagnosticLogger.i("Images scanned: $totalImages")
DiagnosticLogger.i("Faces detected: $totalFacesDetected")
DiagnosticLogger.i("Faces matched: $facesFound")
DiagnosticLogger.i("Hit rate: ${if (totalFacesDetected > 0) (facesFound * 100 / totalFacesDetected) else 0}%")
val duration = (System.currentTimeMillis() - startTime) / 1000.0
DiagnosticLogger.i("=== SCAN COMPLETE in ${String.format("%.2f", duration)}s ===")
_scanningState.value = ScanningState.Complete(
personName = personName,
facesFound = facesFound,
imagesScanned = totalImages,
totalFacesDetected = totalFacesDetected
personName, facesFoundCount.get(), processedCount.get(),
totalFacesDetectedCount.get(), skippedCount.get()
)
loadPersons()
@@ -265,35 +267,102 @@ class PersonInventoryViewModel @Inject constructor(
} catch (e: Exception) {
DiagnosticLogger.e("Scan failed", e)
_scanningState.value = ScanningState.Idle
_uiState.value = InventoryUiState.Error(
"Scan failed: ${e.message}"
_uiState.value = InventoryUiState.Error("Scan failed: ${e.message}")
}
}
}
/**
 * Per-image scan counters: tags matching the target face model, total faces
 * detected in the image, and whether the image was skipped as already tagged.
 */
private data class ImageProcessingResult(
    val matchingTagsCount: Int,
    val totalFacesDetected: Int,
    val skipped: Boolean
)
/**
 * Processes one library image for the optimized scan: skips already-tagged
 * images, detects faces (with caching), picks a threshold, and tags matches.
 *
 * FIXME(review): the body below is placeholder/stub code. `null as? String`
 * is ALWAYS null, so the elvis on `imageId` bails out for EVERY image — this
 * function currently returns null unconditionally and the optimized scan
 * never processes anything (processed/faces/skipped counters stay at 0).
 * The real imageId, uri, and dimensions must be extracted from the concrete
 * ImageWithEverything type; the parameter is currently typed as Any.
 *
 * @return per-image counters, or null when the image could not be processed
 *         (callers treat null as "not counted").
 */
private suspend fun processImageOptimized(
    imageWithEverything: Any,
    faceModelId: String,
    trainingCount: Int,
    alreadyTaggedImageIds: Set<String>
): ImageProcessingResult? = withContext(Dispatchers.Default) {
    try {
        // BUG: `null as? String` always yields null -> this always returns null.
        val imageId = (imageWithEverything as? Any)?.let {
            // Access imageId from your ImageWithEverything type
            // This will depend on your actual type structure
            null as? String
        } ?: return@withContext null
        val imageUri = "" // Extract from imageWithEverything -- placeholder, never a real URI
        val width = 1000 // Extract from imageWithEverything -- placeholder dimension
        val height = 1000 // Extract from imageWithEverything -- placeholder dimension
        // Scan optimization: skip images already tagged for this face model.
        if (imageId in alreadyTaggedImageIds) {
            return@withContext ImageProcessingResult(0, 0, true)
        }
        // Reuse cached detections when the same image is scanned again.
        val detectedFaces = faceDetectionCache.getOrPut(imageId) {
            detectFacesInImageOptimized(imageUri)
        }
        if (detectedFaces.isEmpty()) {
            return@withContext ImageProcessingResult(0, 0, false)
        }
        // Pick a recognition threshold from image quality + detection context,
        // unless the liberal-threshold debug flag is on.
        val imageQuality = ThresholdStrategy.estimateImageQuality(width, height)
        val detectionContext = ThresholdStrategy.estimateDetectionContext(
            detectedFaces.size,
            calculateFaceAreaRatio(detectedFaces[0], width, height)
        )
        val scanThreshold = if (DebugFlags.USE_LIBERAL_THRESHOLDS) {
            ThresholdStrategy.getLiberalThreshold(trainingCount)
        } else {
            ThresholdStrategy.getOptimalThreshold(
                trainingCount, imageQuality, detectionContext
            )
        }
        val tags = faceRecognitionRepository.scanImage(
            imageId, detectedFaces, scanThreshold
        )
        // Only tags for the target model count as "found" for this scan.
        val matchingTags = tags.count { it.faceModelId == faceModelId }
        ImageProcessingResult(matchingTags, detectedFaces.size, false)
    } catch (e: Exception) {
        DiagnosticLogger.e("Failed to process image", e)
        null
    }
}
private suspend fun detectFacesInImage(imageUri: String): List<DetectedFace> = withContext(Dispatchers.Default) {
private suspend fun detectFacesInImageOptimized(imageUri: String): List<DetectedFace> =
withContext(Dispatchers.IO) {
var bitmap: Bitmap? = null
try {
val uri = Uri.parse(imageUri)
val inputStream = getApplication<Application>().contentResolver.openInputStream(uri)
val bitmap = BitmapFactory.decodeStream(inputStream)
inputStream?.close()
if (bitmap == null) {
DiagnosticLogger.w("Failed to load bitmap from: $imageUri")
return@withContext emptyList()
val options = BitmapFactory.Options().apply {
inJustDecodeBounds = true
}
getApplication<Application>().contentResolver.openInputStream(uri)?.use { stream ->
BitmapFactory.decodeStream(stream, null, options)
}
DiagnosticLogger.d("Bitmap: ${bitmap.width}x${bitmap.height}")
options.inSampleSize = calculateInSampleSize(
options.outWidth, options.outHeight, 2048, 2048
)
options.inJustDecodeBounds = false
options.inPreferredConfig = Bitmap.Config.RGB_565
bitmap = getApplication<Application>().contentResolver.openInputStream(uri)?.use { stream ->
BitmapFactory.decodeStream(stream, null, options)
}
if (bitmap == null) return@withContext emptyList()
val image = InputImage.fromBitmap(bitmap, 0)
val faces = faceDetector.process(image).await()
DiagnosticLogger.d("ML Kit found ${faces.size} faces")
faces.mapNotNull { face ->
val boundingBox = face.boundingBox
val croppedFace = try {
val left = boundingBox.left.coerceAtLeast(0)
val top = boundingBox.top.coerceAtLeast(0)
@@ -302,48 +371,176 @@ class PersonInventoryViewModel @Inject constructor(
if (width > 0 && height > 0) {
Bitmap.createBitmap(bitmap, left, top, width, height)
} else {
null
}
} else null
} catch (e: Exception) {
DiagnosticLogger.e("Face crop failed", e)
null
}
if (croppedFace != null) {
DetectedFace(
croppedBitmap = croppedFace,
boundingBox = boundingBox
)
} else {
null
croppedFace?.let {
DetectedFace(croppedBitmap = it, boundingBox = boundingBox)
}
}
} catch (e: Exception) {
DiagnosticLogger.e("Face detection failed: $imageUri", e)
emptyList()
} finally {
bitmap?.recycle()
}
}
/**
* Calculate face area ratio (for context detection)
*/
private fun calculateFaceAreaRatio(
face: DetectedFace,
imageWidth: Int,
imageHeight: Int
): Float {
/**
 * Computes a power-of-two BitmapFactory `inSampleSize` so the decoded bitmap's
 * dimensions stay at or above [reqWidth] x [reqHeight] while using as little
 * memory as possible. Returns 1 when the source already fits the bounds.
 */
private fun calculateInSampleSize(width: Int, height: Int, reqWidth: Int, reqHeight: Int): Int {
    // Source already within the requested bounds: decode at full resolution.
    if (height <= reqHeight && width <= reqWidth) return 1
    var sampleSize = 1
    // Keep doubling while the next halving still meets both requested bounds.
    // (h / 2) / s == h / (2 * s) under truncating integer division, so this is
    // exactly the classic half-dimension formulation.
    while (height / (sampleSize * 2) >= reqHeight && width / (sampleSize * 2) >= reqWidth) {
        sampleSize *= 2
    }
    return sampleSize
}
/**
 * Fraction of the image area covered by the face's bounding box, used as a
 * proxy for detection context (close-up vs. group shot) when picking
 * recognition thresholds.
 *
 * Fix: the original span contained two consecutive `return` statements (a
 * merge/diff artifact) — the first, unguarded division made the zero-guarded
 * return unreachable. Only the guarded return is kept, so a zero-area image
 * (e.g. unknown dimensions reported as 0) yields 0 instead of dividing by zero.
 */
private fun calculateFaceAreaRatio(face: DetectedFace, imageWidth: Int, imageHeight: Int): Float {
    val faceArea = face.boundingBox.width() * face.boundingBox.height()
    val imageArea = imageWidth * imageHeight
    return if (imageArea > 0) faceArea.toFloat() / imageArea.toFloat() else 0f
}
// ============================================================================
// MODEL IMPROVEMENT
// ============================================================================
/**
 * Starts the model-improvement flow for [personId]: resolves the person's
 * display name and current training-image count, then transitions to
 * [ModelImprovementState.SelectingPhotos] so the UI opens the photo picker.
 */
fun startModelImprovement(personId: String, faceModelId: String) {
    viewModelScope.launch {
        try {
            val currentState = _uiState.value
            // Look the person up in the already-loaded inventory (may be absent).
            val person = if (currentState is InventoryUiState.Success) {
                currentState.persons.find { it.person.id == personId }?.person
            } else null
            val personName = person?.name ?: "Unknown"
            val faceModel = faceRecognitionRepository.getFaceModelById(faceModelId)
            // NOTE(review): silently falls back to 15 when the model is missing;
            // the scan path treats a missing model as a hard error instead —
            // confirm which behavior is intended here.
            val currentTrainingCount = faceModel?.trainingImageCount ?: 15
            _improvementState.value = ModelImprovementState.SelectingPhotos(
                personId, personName, faceModelId, currentTrainingCount
            )
        } catch (e: Exception) {
            _improvementState.value = ModelImprovementState.Error(
                "Failed to start: ${e.message}"
            )
        }
    }
}
/**
 * Runs sanity checks over the photos the user picked (minimum of 5, face
 * presence, near-duplicate detection at 0.95 similarity), reporting progress
 * through [ModelImprovementState.ValidatingPhotos], then transitions to
 * [ModelImprovementState.ReviewingPhotos] with the result.
 *
 * No-ops unless the flow is currently in SelectingPhotos.
 */
fun processSelectedPhotos(
    personId: String,
    faceModelId: String,
    selectedImageUris: List<Uri>
) {
    viewModelScope.launch {
        try {
            val currentState = _improvementState.value
            // Guard: only valid when the picker was launched from SelectingPhotos.
            if (currentState !is ModelImprovementState.SelectingPhotos) return@launch
            val sanityCheckResult = sanityChecker.performSanityChecks(
                imageUris = selectedImageUris,
                minImagesRequired = 5,
                allowMultipleFaces = true,
                duplicateSimilarityThreshold = 0.95,
                onProgress = { phase, current, total ->
                    _improvementState.value = ModelImprovementState.ValidatingPhotos(
                        personId, currentState.personName, faceModelId,
                        phase, current, total
                    )
                }
            )
            _improvementState.value = ModelImprovementState.ReviewingPhotos(
                personId, currentState.personName, faceModelId,
                sanityCheckResult, currentState.currentTrainingCount
            )
        } catch (e: Exception) {
            _improvementState.value = ModelImprovementState.Error(
                "Validation failed: ${e.message}"
            )
        }
    }
}
/**
 * Retrains [faceModelId] with the validated face crops from [sanityCheckResult],
 * then surfaces before/after stats via [ModelImprovementState.TrainingComplete]
 * and auto-dismisses the success dialog after 3 seconds.
 *
 * Fixes vs. the original:
 *  - the post-retrain model lookup used `!!`; now `checkNotNull` with a
 *    descriptive message (still lands in the catch below as an error state);
 *  - the delayed auto-dismiss unconditionally reset the state to Idle, which
 *    could clobber a NEW improvement flow started within those 3 seconds —
 *    the reset is now a compareAndSet guarded on this flow's complete state.
 *
 * No-ops unless the flow is currently in ReviewingPhotos.
 */
fun retrainModelWithValidatedPhotos(
    personId: String,
    faceModelId: String,
    sanityCheckResult: TrainingSanityChecker.SanityCheckResult
) {
    viewModelScope.launch {
        try {
            val currentState = _improvementState.value
            // Only valid transition is out of the review dialog.
            if (currentState !is ModelImprovementState.ReviewingPhotos) return@launch
            val validImages = sanityCheckResult.validImagesWithFaces
            if (validImages.isEmpty()) {
                _improvementState.value = ModelImprovementState.Error("No valid photos")
                return@launch
            }
            // Snapshot the pre-retrain model so we can show before/after stats.
            val currentModel = faceRecognitionRepository.getFaceModelById(faceModelId)
                ?: throw IllegalStateException("Face model not found")
            _improvementState.value = ModelImprovementState.Training(
                currentState.personName, 0, validImages.size + 1,
                "Extracting embeddings..."
            )
            // Repository handles embedding extraction and model update.
            faceRecognitionRepository.retrainFaceModel(
                faceModelId = faceModelId,
                newFaceImages = validImages.map { it.croppedFaceBitmap }
            )
            val updatedModel = checkNotNull(faceRecognitionRepository.getFaceModelById(faceModelId)) {
                "Face model $faceModelId missing after retraining"
            }
            // Cached detections may feed stale tags for the old model; invalidate.
            faceDetectionCache.clear()
            val completeState = ModelImprovementState.TrainingComplete(
                currentState.personName,
                validImages.size,
                updatedModel.trainingImageCount,
                currentModel.averageConfidence,
                updatedModel.averageConfidence
            )
            _improvementState.value = completeState
            loadPersons()
            // Auto-dismiss the success dialog — but only if the state is still
            // ours (the user may have dismissed it or started another flow).
            delay(3000)
            _improvementState.compareAndSet(completeState, ModelImprovementState.Idle)
        } catch (e: Exception) {
            DiagnosticLogger.e("Retraining failed", e)
            _improvementState.value = ModelImprovementState.Error(
                "Retraining failed: ${e.message}"
            )
        }
    }
}
/** Aborts the improvement flow from any state and returns it to Idle. */
fun cancelModelImprovement() {
    _improvementState.value = ModelImprovementState.Idle
}
/** Images tagged for [personId], fetched straight from the repository. */
suspend fun getPersonImages(personId: String) =
    faceRecognitionRepository.getImagesForPerson(personId)
/** Drops all cached per-image face detections (e.g. after models change). */
fun clearCaches() {
    faceDetectionCache.clear()
}
override fun onCleared() {
    super.onCleared()
    // Release the ML Kit detector, helpers, and caches tied to this ViewModel.
    faceDetector.close()
    faceDetectionHelper.cleanup()
    sanityChecker.cleanup()
    clearCaches()
}
}

View File

@@ -2,7 +2,9 @@ package com.placeholder.sherpai2.ui.presentation
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.*
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
@@ -18,11 +20,8 @@ import androidx.compose.material.icons.filled.*
import com.placeholder.sherpai2.ui.navigation.AppRoutes
/**
* CLEAN & COMPACT Drawer
* - 280dp width (not 300dp)
* - Icon + SherpAI inline (not stacked)
* - NO subtitles (clean single-line items)
* - Terrain icon (mountain theme)
* SLIMMED DOWN AppDrawer - 280dp width, inline logo, cleaner sections
* NOW WITH: Scrollable support for small phones + Collections item
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@@ -31,12 +30,17 @@ fun AppDrawerContent(
onDestinationClicked: (String) -> Unit
) {
ModalDrawerSheet(
modifier = Modifier.width(280.dp), // Narrower!
modifier = Modifier.width(280.dp), // SLIMMER (was 300dp)
drawerContainerColor = MaterialTheme.colorScheme.surface
) {
Column(modifier = Modifier.fillMaxSize()) {
// SCROLLABLE Column - works on small phones!
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(rememberScrollState())
) {
// ===== COMPACT INLINE HEADER =====
// ===== COMPACT HEADER - Icon + Text Inline =====
Box(
modifier = Modifier
.fillMaxWidth()
@@ -48,22 +52,22 @@ fun AppDrawerContent(
)
)
)
.padding(20.dp) // Tighter padding
.padding(20.dp) // Reduced padding
) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp)
horizontalArrangement = Arrangement.spacedBy(12.dp),
verticalAlignment = Alignment.CenterVertically
) {
// Icon - TERRAIN (mountain theme!)
// App icon - smaller
Surface(
modifier = Modifier.size(48.dp), // Smaller
shape = RoundedCornerShape(12.dp),
modifier = Modifier.size(48.dp), // Smaller (was 56dp)
shape = RoundedCornerShape(14.dp),
color = MaterialTheme.colorScheme.primary,
shadowElevation = 4.dp
) {
Box(contentAlignment = Alignment.Center) {
Icon(
Icons.Default.Terrain, // Mountain icon!
Icons.Default.Terrain, // Mountain theme!
contentDescription = null,
modifier = Modifier.size(28.dp),
tint = MaterialTheme.colorScheme.onPrimary
@@ -71,32 +75,32 @@ fun AppDrawerContent(
}
}
// Text INLINE with icon
Column {
// Text next to icon
Column(verticalArrangement = Arrangement.spacedBy(2.dp)) {
Text(
"SherpAI",
style = MaterialTheme.typography.titleLarge,
style = MaterialTheme.typography.titleLarge, // Smaller (was headlineMedium)
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.onSurface
)
Text(
"Face Recognition",
style = MaterialTheme.typography.bodySmall,
style = MaterialTheme.typography.bodySmall, // Smaller
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
Spacer(modifier = Modifier.height(8.dp))
Spacer(modifier = Modifier.height(4.dp)) // Reduced spacing
// ===== NAVIGATION ITEMS - COMPACT =====
// ===== NAVIGATION SECTIONS =====
Column(
modifier = Modifier
.fillMaxWidth()
.weight(1f)
.padding(horizontal = 12.dp),
verticalArrangement = Arrangement.spacedBy(4.dp)
.padding(horizontal = 8.dp), // Reduced padding
verticalArrangement = Arrangement.spacedBy(2.dp) // Tighter spacing
) {
// Photos Section
@@ -105,7 +109,7 @@ fun AppDrawerContent(
val photoItems = listOf(
DrawerItem(AppRoutes.SEARCH, "Search", Icons.Default.Search),
DrawerItem(AppRoutes.EXPLORE, "Explore", Icons.Default.Explore),
DrawerItem(AppRoutes.COLLECTIONS, "Collections", Icons.Default.Collections)
DrawerItem(AppRoutes.COLLECTIONS, "Collections", Icons.Default.Collections) // NEW!
)
photoItems.forEach { item ->
@@ -116,7 +120,7 @@ fun AppDrawerContent(
)
}
Spacer(modifier = Modifier.height(8.dp))
Spacer(modifier = Modifier.height(4.dp))
// Face Recognition Section
DrawerSection(title = "Face Recognition")
@@ -135,7 +139,7 @@ fun AppDrawerContent(
)
}
Spacer(modifier = Modifier.height(8.dp))
Spacer(modifier = Modifier.height(4.dp))
// Organization Section
DrawerSection(title = "Organization")
@@ -153,11 +157,11 @@ fun AppDrawerContent(
)
}
Spacer(modifier = Modifier.weight(1f))
Spacer(modifier = Modifier.height(8.dp))
// Settings at bottom
HorizontalDivider(
modifier = Modifier.padding(vertical = 8.dp),
modifier = Modifier.padding(vertical = 6.dp),
color = MaterialTheme.colorScheme.outlineVariant
)
@@ -171,28 +175,28 @@ fun AppDrawerContent(
onClick = { onDestinationClicked(AppRoutes.SETTINGS) }
)
Spacer(modifier = Modifier.height(8.dp))
Spacer(modifier = Modifier.height(16.dp)) // Bottom padding for scroll
}
}
}
}
/**
* Section header
* Section header - more compact
*/
@Composable
private fun DrawerSection(title: String) {
Text(
text = title,
style = MaterialTheme.typography.labelMedium,
style = MaterialTheme.typography.labelSmall, // Smaller
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.primary,
modifier = Modifier.padding(horizontal = 16.dp, vertical = 8.dp)
modifier = Modifier.padding(horizontal = 16.dp, vertical = 6.dp) // Reduced padding
)
}
/**
* COMPACT navigation item - NO SUBTITLES
* Navigation item - cleaner, no subtitle
*/
@Composable
private fun DrawerNavigationItem(
@@ -204,7 +208,7 @@ private fun DrawerNavigationItem(
label = {
Text(
text = item.label,
style = MaterialTheme.typography.bodyLarge,
style = MaterialTheme.typography.bodyMedium, // Slightly smaller
fontWeight = if (selected) FontWeight.SemiBold else FontWeight.Normal
)
},
@@ -212,14 +216,14 @@ private fun DrawerNavigationItem(
Icon(
item.icon,
contentDescription = item.label,
modifier = Modifier.size(24.dp)
modifier = Modifier.size(22.dp) // Slightly smaller
)
},
selected = selected,
onClick = onClick,
modifier = Modifier
.padding(NavigationDrawerItemDefaults.ItemPadding)
.clip(RoundedCornerShape(12.dp)),
.clip(RoundedCornerShape(10.dp)), // Slightly smaller radius
colors = NavigationDrawerItemDefaults.colors(
selectedContainerColor = MaterialTheme.colorScheme.primaryContainer,
selectedIconColor = MaterialTheme.colorScheme.primary,
@@ -230,12 +234,10 @@ private fun DrawerNavigationItem(
}
/**
* Simple drawer item - no subtitle needed
* Simplified drawer item (no subtitle)
*/
private data class DrawerItem(
val route: String,
val label: String,
val icon: androidx.compose.ui.graphics.vector.ImageVector
)
//TODO we also lost the tight gradient top part

View File

@@ -1,10 +1,5 @@
package com.placeholder.sherpai2.ui.presentation
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.fadeIn
import androidx.compose.animation.fadeOut
import androidx.compose.animation.slideInVertically
import androidx.compose.animation.slideOutVertically
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons
@@ -20,7 +15,7 @@ import com.placeholder.sherpai2.ui.navigation.AppRoutes
import kotlinx.coroutines.launch
/**
* Beautiful main screen with gradient header, dynamic actions, and polish
* Clean main screen - NO duplicate FABs, Collections support
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@@ -103,15 +98,7 @@ fun MainScreen() {
)
}
}
AppRoutes.TAGS -> {
IconButton(onClick = { /* TODO: Add tag */ }) {
Icon(
Icons.Default.Add,
contentDescription = "Add Tag",
tint = MaterialTheme.colorScheme.primary
)
}
}
// NOTE: Removed TAGS action - TagManagementScreen has its own inline FAB
}
},
colors = TopAppBarDefaults.topAppBarColors(
@@ -122,6 +109,7 @@ fun MainScreen() {
)
)
}
// NOTE: NO floatingActionButton here - individual screens manage their own FABs inline
) { paddingValues ->
AppNavHost(
navController = navController,
@@ -137,7 +125,8 @@ fun MainScreen() {
private fun getScreenTitle(route: String): String {
return when (route) {
AppRoutes.SEARCH -> "Search"
AppRoutes.EXPLORE -> "Explore" // Will be renamed to EXPLORE
AppRoutes.EXPLORE -> "Explore"
AppRoutes.COLLECTIONS -> "Collections" // NEW!
AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train New Person"
AppRoutes.MODELS -> "AI Models"
@@ -155,6 +144,7 @@ private fun getScreenSubtitle(route: String): String? {
return when (route) {
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
AppRoutes.EXPLORE -> "Browse your collection"
AppRoutes.COLLECTIONS -> "Your photo collections" // NEW!
AppRoutes.INVENTORY -> "Trained face models"
AppRoutes.TRAIN -> "Add a new person to recognize"
AppRoutes.TAGS -> "Organize your photo collection"