diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Duplicateimagedetector.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Duplicateimagedetector.kt new file mode 100644 index 0000000..66b225b --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Duplicateimagedetector.kt @@ -0,0 +1,159 @@ +package com.placeholder.sherpai2.ui.trainingprep + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.net.Uri +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.InputStream + +/** + * Helper class for detecting duplicate or near-duplicate images using perceptual hashing + */ +class DuplicateImageDetector(private val context: Context) { + + data class DuplicateCheckResult( + val hasDuplicates: Boolean, + val duplicateGroups: List, + val uniqueImageCount: Int + ) + + data class DuplicateGroup( + val images: List, + val similarity: Double + ) + + private data class ImageHash( + val uri: Uri, + val hash: Long + ) + + /** + * Check for duplicate images in the provided list + */ + suspend fun checkForDuplicates( + uris: List, + similarityThreshold: Double = 0.95 + ): DuplicateCheckResult = withContext(Dispatchers.Default) { + if (uris.size < 2) { + return@withContext DuplicateCheckResult( + hasDuplicates = false, + duplicateGroups = emptyList(), + uniqueImageCount = uris.size + ) + } + + // Compute perceptual hash for each image + val imageHashes = uris.mapNotNull { uri -> + try { + val bitmap = loadBitmap(uri) + bitmap?.let { + val hash = computePerceptualHash(it) + ImageHash(uri, hash) + } + } catch (e: Exception) { + null + } + } + + // Find duplicate groups + val duplicateGroups = mutableListOf() + val processed = mutableSetOf() + + for (i in imageHashes.indices) { + if (imageHashes[i].uri in processed) continue + + val currentGroup = mutableListOf(imageHashes[i].uri) + + for (j in i + 1 until imageHashes.size) { + if (imageHashes[j].uri in processed) continue + + val similarity = calculateSimilarity(imageHashes[i].hash, imageHashes[j].hash) + + if (similarity >= similarityThreshold) { + currentGroup.add(imageHashes[j].uri) + processed.add(imageHashes[j].uri) + } + } + + if (currentGroup.size > 1) { + duplicateGroups.add( + DuplicateGroup( + images = currentGroup, + similarity = 1.0 + ) + ) + processed.addAll(currentGroup) + } + } + + DuplicateCheckResult( + hasDuplicates = duplicateGroups.isNotEmpty(), + duplicateGroups = duplicateGroups, + uniqueImageCount = uris.size - duplicateGroups.sumOf { it.images.size - 1 } + ) + } + + /** + * Compute perceptual hash using difference hash (dHash) algorithm + */ + private fun computePerceptualHash(bitmap: Bitmap): Long { + // Resize to 9x8 + val resized = Bitmap.createScaledBitmap(bitmap, 9, 8, false) + + var hash = 0L + var bitIndex = 0 + + for (y in 0 until 8) { + for (x in 0 until 8) { + val leftPixel = resized.getPixel(x, y) + val rightPixel = resized.getPixel(x + 1, y) + + val leftGray = toGrayscale(leftPixel) + val rightGray = toGrayscale(rightPixel) + + if (leftGray > rightGray) { + hash = hash or (1L shl bitIndex) + } + bitIndex++ + } + } + + resized.recycle() + return hash + } + + /** + * Convert RGB pixel to grayscale value + */ + private fun toGrayscale(pixel: Int): Int { + val r = (pixel shr 16) and 0xFF + val g = (pixel shr 8) and 0xFF + val b = pixel and 0xFF + return (0.299 * r + 0.587 * g + 0.114 * b).toInt() + } + + /** + * Calculate similarity between two hashes + */ + private fun calculateSimilarity(hash1: Long, hash2: Long): Double { + val xor = hash1 xor hash2 + val hammingDistance = xor.countOneBits() + return 1.0 - (hammingDistance / 64.0) + } + + /** + * Load bitmap from URI + */ + private fun loadBitmap(uri: Uri): Bitmap? { + return try { + val inputStream: InputStream? = context.contentResolver.openInputStream(uri) + BitmapFactory.decodeStream(inputStream)?.also { + inputStream?.close() + } + } catch (e: Exception) { + null + } + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/FacePickerDialog.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/FacePickerDialog.kt new file mode 100644 index 0000000..0db70b7 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/FacePickerDialog.kt @@ -0,0 +1,435 @@ +package com.placeholder.sherpai2.ui.trainingprep + +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.Rect +import android.net.Uri +import androidx.compose.foundation.BorderStroke +import androidx.compose.foundation.Canvas +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.CheckCircle +import androidx.compose.material.icons.filled.Close +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.geometry.Offset +import androidx.compose.ui.geometry.Size +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.asImageBitmap +import androidx.compose.ui.graphics.drawscope.Stroke +import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog +import androidx.compose.ui.window.DialogProperties +import coil.compose.AsyncImage +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext + +/** + * Dialog for selecting a face from multiple detected faces + */ +@Composable +fun FacePickerDialog( + result: FaceDetectionHelper.FaceDetectionResult, + onDismiss: () -> Unit, + onFaceSelected: (Int, Bitmap) -> Unit // faceIndex, croppedFaceBitmap +) { + val context = LocalContext.current + var selectedFaceIndex by remember { mutableStateOf(null) } + var croppedFaces by remember { mutableStateOf>(emptyList()) } + var isLoading by remember { mutableStateOf(true) } + + // Load and crop all faces + LaunchedEffect(result) { + isLoading = true + croppedFaces = withContext(Dispatchers.IO) { + val bitmap = loadBitmapFromUri(context, result.uri) + bitmap?.let { bmp -> + result.faceBounds.map { bounds -> + cropFaceFromBitmap(bmp, bounds) + } + } ?: emptyList() + } + isLoading = false + // Auto-select the first (largest) face + if (croppedFaces.isNotEmpty()) { + selectedFaceIndex = 0 + } + } + + Dialog( + onDismissRequest = onDismiss, + properties = DialogProperties(usePlatformDefaultWidth = false) + ) { + Card( + modifier = Modifier + .fillMaxWidth(0.95f) + .fillMaxHeight(0.9f), + shape = RoundedCornerShape(16.dp) + ) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(20.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Header + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Column { + Text( + text = "Pick a Face", + style = MaterialTheme.typography.headlineSmall, + fontWeight = FontWeight.Bold + ) + Text( + text = "${result.faceCount} faces detected", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + IconButton(onClick = onDismiss) { + Icon(Icons.Default.Close, "Close") + } + } + + // Instruction + Text( + text = "Tap a face below to select it for training:", + style = MaterialTheme.typography.bodyMedium + ) + + if (isLoading) { + // Loading state + Box( + modifier = Modifier + .fillMaxWidth() + .weight(1f), + contentAlignment = Alignment.Center + ) { + CircularProgressIndicator() + } + } else { + // Original image with face boxes overlay + Card( + modifier = Modifier + .fillMaxWidth() + .weight(1f), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.surfaceVariant + ) + ) { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + FaceOverlayImage( + imageUri = result.uri, + faceBounds = result.faceBounds, + selectedFaceIndex = selectedFaceIndex, + onFaceClick = { index -> + selectedFaceIndex = index + } + ) + } + } + + // Face previews grid + Text( + text = "Preview (tap to select):", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold + ) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + croppedFaces.forEachIndexed { index, faceBitmap -> + FacePreviewCard( + faceBitmap = faceBitmap, + index = index, + isSelected = selectedFaceIndex == index, + onClick = { selectedFaceIndex = index }, + modifier = Modifier.weight(1f) + ) + } + } + } + + // Action buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onDismiss, + modifier = Modifier.weight(1f) + ) { + Text("Cancel") + } + + Button( + onClick = { + selectedFaceIndex?.let { index -> + if (index < croppedFaces.size) { + onFaceSelected(index, croppedFaces[index]) + } + } + }, + modifier = Modifier.weight(1f), + enabled = selectedFaceIndex != null && !isLoading + ) { + Icon(Icons.Default.CheckCircle, contentDescription = null) + Spacer(modifier = Modifier.width(8.dp)) + Text("Use This Face") + } + } + } + } + } +} + +/** + * Image with interactive face boxes overlay + */ +@Composable +private fun FaceOverlayImage( + imageUri: Uri, + faceBounds: List, + selectedFaceIndex: Int?, + onFaceClick: (Int) -> Unit +) { + var imageSize by remember { mutableStateOf(Size.Zero) } + var imageBounds by remember { mutableStateOf(Rect()) } + + Box( + modifier = Modifier.fillMaxSize() + ) { + // Original image + AsyncImage( + model = imageUri, + contentDescription = "Original image", + modifier = Modifier + .fillMaxSize() + .padding(8.dp), + contentScale = ContentScale.Fit, + onSuccess = { state -> + val drawable = state.result.drawable + imageBounds = Rect(0, 0, drawable.intrinsicWidth, drawable.intrinsicHeight) + } + ) + + // Face boxes overlay + Canvas( + modifier = Modifier + .fillMaxSize() + .padding(8.dp) + ) { + if (imageBounds.width() > 0 && imageBounds.height() > 0) { + // Calculate scale to fit image in canvas + val scaleX = size.width / imageBounds.width() + val scaleY = size.height / imageBounds.height() + val scale = minOf(scaleX, scaleY) + + // Calculate offset to center image + val scaledWidth = imageBounds.width() * scale + val scaledHeight = imageBounds.height() * scale + val offsetX = (size.width - scaledWidth) / 2 + val offsetY = (size.height - scaledHeight) / 2 + + faceBounds.forEachIndexed { index, bounds -> + val isSelected = selectedFaceIndex == index + + // Scale and position the face box + val left = bounds.left * scale + offsetX + val top = bounds.top * scale + offsetY + val width = bounds.width() * scale + val height = bounds.height() * scale + + // Draw box + drawRect( + color = if (isSelected) Color(0xFF4CAF50) else Color(0xFF2196F3), + topLeft = Offset(left, top), + size = Size(width, height), + style = Stroke(width = if (isSelected) 6f else 4f) + ) + + // Draw semi-transparent fill for selected + if (isSelected) { + drawRect( + color = Color(0xFF4CAF50).copy(alpha = 0.2f), + topLeft = Offset(left, top), + size = Size(width, height) + ) + } + + // Draw face number label + drawCircle( + color = if (isSelected) Color(0xFF4CAF50) else Color(0xFF2196F3), + radius = 20f * scale, + center = Offset(left + 20f * scale, top + 20f * scale) + ) + } + } + } + + // Clickable areas for each face + faceBounds.forEachIndexed { index, bounds -> + if (imageBounds.width() > 0 && imageBounds.height() > 0) { + val scaleX = imageSize.width / imageBounds.width() + val scaleY = imageSize.height / imageBounds.height() + val scale = minOf(scaleX, scaleY) + + val scaledWidth = imageBounds.width() * scale + val scaledHeight = imageBounds.height() * scale + val offsetX = (imageSize.width - scaledWidth) / 2 + val offsetY = (imageSize.height - scaledHeight) / 2 + + Box( + modifier = Modifier + .fillMaxSize() + .clickable { onFaceClick(index) } + ) + } + } + } + + // Update image size + BoxWithConstraints { + LaunchedEffect(constraints) { + imageSize = Size(constraints.maxWidth.toFloat(), constraints.maxHeight.toFloat()) + } + } +} + +/** + * Individual face preview card + */ +@Composable +private fun FacePreviewCard( + faceBitmap: Bitmap, + index: Int, + isSelected: Boolean, + onClick: () -> Unit, + modifier: Modifier = Modifier +) { + Card( + modifier = modifier + .aspectRatio(1f) + .clickable(onClick = onClick), + colors = CardDefaults.cardColors( + containerColor = if (isSelected) + MaterialTheme.colorScheme.primaryContainer + else + MaterialTheme.colorScheme.surface + ), + border = if (isSelected) + BorderStroke(3.dp, MaterialTheme.colorScheme.primary) + else + BorderStroke(1.dp, MaterialTheme.colorScheme.outline) + ) { + Box( + modifier = Modifier.fillMaxSize() + ) { + androidx.compose.foundation.Image( + bitmap = faceBitmap.asImageBitmap(), + contentDescription = "Face ${index + 1}", + modifier = Modifier.fillMaxSize(), + contentScale = ContentScale.Crop + ) + + // Selected checkmark (only show when selected) + if (isSelected) { + Surface( + modifier = Modifier + .align(Alignment.Center), + shape = CircleShape, + color = MaterialTheme.colorScheme.primary.copy(alpha = 0.9f) + ) { + Icon( + Icons.Default.CheckCircle, + contentDescription = "Selected", + modifier = Modifier + .padding(12.dp) + .size(32.dp), + tint = MaterialTheme.colorScheme.onPrimary + ) + } + } + + // Face number badge (always in top-right, small) + Surface( + modifier = Modifier + .align(Alignment.TopEnd) + .padding(4.dp), + shape = CircleShape, + color = if (isSelected) + MaterialTheme.colorScheme.primary + else + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.9f), + shadowElevation = 2.dp + ) { + Text( + text = "${index + 1}", + modifier = Modifier.padding(6.dp), + style = MaterialTheme.typography.labelSmall, + fontWeight = FontWeight.Bold, + color = if (isSelected) + MaterialTheme.colorScheme.onPrimary + else + MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } +} + +/** + * Helper function to load bitmap from URI + */ +private suspend fun loadBitmapFromUri( + context: android.content.Context, + uri: Uri +): Bitmap? = withContext(Dispatchers.IO) { + try { + val inputStream = context.contentResolver.openInputStream(uri) + BitmapFactory.decodeStream(inputStream)?.also { + inputStream?.close() + } + } catch (e: Exception) { + null + } +} + +/** + * Helper function to crop face from bitmap + */ +private fun cropFaceFromBitmap(bitmap: Bitmap, faceBounds: Rect): Bitmap { + // Add 20% padding around the face + val padding = (faceBounds.width() * 0.2f).toInt() + + val left = (faceBounds.left - padding).coerceAtLeast(0) + val top = (faceBounds.top - padding).coerceAtLeast(0) + val right = (faceBounds.right + padding).coerceAtMost(bitmap.width) + val bottom = (faceBounds.bottom + padding).coerceAtMost(bitmap.height) + + val width = right - left + val height = bottom - top + + return Bitmap.createBitmap(bitmap, left, top, width, height) +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Facedetectionhelper.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Facedetectionhelper.kt new file mode 100644 index 0000000..b84c1e4 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Facedetectionhelper.kt @@ -0,0 +1,124 @@ +package com.placeholder.sherpai2.ui.trainingprep + +import android.content.Context +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.Rect +import android.net.Uri +import com.google.mlkit.vision.common.InputImage +import com.google.mlkit.vision.face.FaceDetection +import com.google.mlkit.vision.face.FaceDetectorOptions +import kotlinx.coroutines.tasks.await +import java.io.InputStream + +/** + * Helper class for detecting faces in images using ML Kit Face Detection + */ +class FaceDetectionHelper(private val context: Context) { + + private val faceDetectorOptions = FaceDetectorOptions.Builder() + .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) + .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) + .setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_ALL) + .setMinFaceSize(0.15f) // Detect faces that are at least 15% of image + .build() + + private val detector = FaceDetection.getClient(faceDetectorOptions) + + data class FaceDetectionResult( + val uri: Uri, + val hasFace: Boolean, + val faceCount: Int, + val faceBounds: List = emptyList(), + val croppedFaceBitmap: Bitmap? = null, + val errorMessage: String? = null + ) + + /** + * Detect faces in a single image + */ + suspend fun detectFacesInImage(uri: Uri): FaceDetectionResult { + return try { + val bitmap = loadBitmap(uri) + if (bitmap == null) { + return FaceDetectionResult( + uri = uri, + hasFace = false, + faceCount = 0, + errorMessage = "Failed to load image" + ) + } + + val inputImage = InputImage.fromBitmap(bitmap, 0) + val faces = detector.process(inputImage).await() + + val croppedFace = if (faces.isNotEmpty()) { + // Crop the first detected face with some padding + cropFaceFromBitmap(bitmap, faces[0].boundingBox) + } else null + + FaceDetectionResult( + uri = uri, + hasFace = faces.isNotEmpty(), + faceCount = faces.size, + faceBounds = faces.map { it.boundingBox }, + croppedFaceBitmap = croppedFace + ) + } catch (e: Exception) { + FaceDetectionResult( + uri = uri, + hasFace = false, + faceCount = 0, + errorMessage = e.message ?: "Unknown error" + ) + } + } + + /** + * Detect faces in multiple images + */ + suspend fun detectFacesInImages(uris: List): List { + return uris.map { uri -> + detectFacesInImage(uri) + } + } + + /** + * Crop face from bitmap with padding + */ + private fun cropFaceFromBitmap(bitmap: Bitmap, faceBounds: Rect): Bitmap { + // Add 20% padding around the face + val padding = (faceBounds.width() * 0.2f).toInt() + + val left = (faceBounds.left - padding).coerceAtLeast(0) + val top = (faceBounds.top - padding).coerceAtLeast(0) + val right = (faceBounds.right + padding).coerceAtMost(bitmap.width) + val bottom = (faceBounds.bottom + padding).coerceAtMost(bitmap.height) + + val width = right - left + val height = bottom - top + + return Bitmap.createBitmap(bitmap, left, top, width, height) + } + + /** + * Load bitmap from URI + */ + private fun loadBitmap(uri: Uri): Bitmap? { + return try { + val inputStream: InputStream? = context.contentResolver.openInputStream(uri) + BitmapFactory.decodeStream(inputStream)?.also { + inputStream?.close() + } + } catch (e: Exception) { + null + } + } + + /** + * Clean up resources + */ + fun cleanup() { + detector.close() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt index 9b61b5b..d30ee99 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/ScanResultsScreen.kt @@ -1,29 +1,629 @@ package com.placeholder.sherpai2.ui.trainingprep import android.net.Uri +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.PickVisualMediaRequest +import androidx.activity.result.contract.ActivityResultContracts +import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.Image import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.* import androidx.compose.foundation.lazy.LazyColumn -import androidx.compose.foundation.lazy.items // CRITICAL: This is the correct import for List items +import androidx.compose.foundation.lazy.itemsIndexed +import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* import androidx.compose.material3.* -import androidx.compose.runtime.Composable +import androidx.compose.runtime.* import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.asImageBitmap import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp -import coil.compose.rememberAsyncImagePainter -import androidx.compose.foundation.lazy.items +import androidx.hilt.navigation.compose.hiltViewModel +import coil.compose.AsyncImage -/** - * Displays the outcome of the face detection process. - */ +@OptIn(ExperimentalMaterial3Api::class) @Composable fun ScanResultsScreen( state: ScanningState, - onFinish: () -> Unit + onFinish: () -> Unit, + trainViewModel: TrainViewModel = hiltViewModel() +) { + var showFacePickerDialog by remember { mutableStateOf(null) } + + Scaffold( + topBar = { + TopAppBar( + title = { Text("Training Image Analysis") }, + colors = TopAppBarDefaults.topAppBarColors( + containerColor = MaterialTheme.colorScheme.primaryContainer + ) + ) + } + ) { paddingValues -> + Box( + modifier = Modifier + .fillMaxSize() + .padding(paddingValues) + ) { + when (state) { + is ScanningState.Idle -> { + // Should not happen + } + + is ScanningState.Processing -> { + ProcessingView( + progress = state.progress, + total = state.total + ) + } + + is ScanningState.Success -> { + ImprovedResultsView( + result = state.sanityCheckResult, + onContinue = onFinish, + onRetry = onFinish, + onReplaceImage = { oldUri, newUri -> + trainViewModel.replaceImage(oldUri, newUri) + }, + onSelectFaceFromMultiple = { result -> + showFacePickerDialog = result + } + ) + } + + is ScanningState.Error -> { + ErrorView( + message = state.message, + onRetry = onFinish + ) + } + } + } + } + + // Face Picker Dialog + showFacePickerDialog?.let { result -> + FacePickerDialog( + result = result, + onDismiss = { showFacePickerDialog = null }, + onFaceSelected = { faceIndex, croppedFaceBitmap -> + trainViewModel.selectFaceFromImage(result.uri, faceIndex, croppedFaceBitmap) + showFacePickerDialog = null + } + ) + } +} + +@Composable +private fun ProcessingView(progress: Int, total: Int) { + Column( + modifier = Modifier.fillMaxSize(), + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center + ) { + CircularProgressIndicator( + modifier = Modifier.size(64.dp), + strokeWidth = 6.dp + ) + Spacer(modifier = Modifier.height(24.dp)) + Text( + text = "Analyzing images...", + style = MaterialTheme.typography.titleMedium + ) + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Detecting faces and checking for duplicates", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + if (total > 0) { + Spacer(modifier = Modifier.height(16.dp)) + LinearProgressIndicator( + progress = { (progress.toFloat() / total.toFloat()).coerceIn(0f, 1f) }, + modifier = Modifier.width(200.dp) + ) + Text( + text = "$progress / $total", + style = MaterialTheme.typography.bodySmall + ) + } + } +} + +@Composable +private fun ImprovedResultsView( + result: TrainingSanityChecker.SanityCheckResult, + onContinue: () -> Unit, + onRetry: () -> Unit, + onReplaceImage: (Uri, Uri) -> Unit, + onSelectFaceFromMultiple: (FaceDetectionHelper.FaceDetectionResult) -> Unit +) { + LazyColumn( + modifier = Modifier.fillMaxSize(), + contentPadding = PaddingValues(16.dp), + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + // Welcome Header + item { + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer + ) + ) { + Column( + modifier = Modifier.padding(16.dp) + ) { + Text( + text = "Analysis Complete!", + style = MaterialTheme.typography.headlineSmall, + fontWeight = FontWeight.Bold + ) + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = "Review your images below. Tap 'Pick Face' on group photos to choose which person to train on, or 'Replace' to swap out any image.", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSecondaryContainer.copy(alpha = 0.8f) + ) + } + } + } + + // Progress Summary + item { + ProgressSummaryCard( + totalImages = result.faceDetectionResults.size, + validImages = result.validImagesWithFaces.size, + requiredImages = 10, + isValid = result.isValid + ) + } + + // Image List Header + item { + Text( + text = "Your Images (${result.faceDetectionResults.size})", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + } + + // Image List with Actions + itemsIndexed(result.faceDetectionResults) { index, imageResult -> + ImageResultCard( + index = index + 1, + result = imageResult, + onReplace = { newUri -> + onReplaceImage(imageResult.uri, newUri) + }, + onSelectFace = if (imageResult.faceCount > 1) { + { onSelectFaceFromMultiple(imageResult) } + } else null + ) + } + + // Validation Issues (if any) + if (result.validationErrors.isNotEmpty()) { + item { + Spacer(modifier = Modifier.height(8.dp)) + ValidationIssuesCard(errors = result.validationErrors) + } + } + + // Action Button + item { + Spacer(modifier = Modifier.height(8.dp)) + Button( + onClick = if (result.isValid) onContinue else onRetry, + modifier = Modifier.fillMaxWidth(), + enabled = result.isValid, + colors = ButtonDefaults.buttonColors( + containerColor = if (result.isValid) + MaterialTheme.colorScheme.primary + else + MaterialTheme.colorScheme.error.copy(alpha = 0.5f) + ) + ) { + Icon( + if (result.isValid) Icons.Default.CheckCircle else Icons.Default.Warning, + contentDescription = null + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + if (result.isValid) + "Continue to Training (${result.validImagesWithFaces.size} images)" + else + "Fix ${result.validationErrors.size} Issue${if (result.validationErrors.size != 1) "s" else ""} to Continue" + ) + } + + if (!result.isValid) { + Spacer(modifier = Modifier.height(8.dp)) + Surface( + modifier = Modifier.fillMaxWidth(), + color = MaterialTheme.colorScheme.tertiaryContainer, + shape = RoundedCornerShape(8.dp) + ) { + Row( + modifier = Modifier.padding(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.Info, + contentDescription = null, + tint = MaterialTheme.colorScheme.onTertiaryContainer, + modifier = Modifier.size(20.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = "Tip: Use 'Replace' to swap problematic images, or 'Pick Face' to choose from group photos", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onTertiaryContainer + ) + } + } + } + } + } +} + +@Composable +private fun ProgressSummaryCard( + totalImages: Int, + validImages: Int, + requiredImages: Int, + isValid: Boolean +) { + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = if (isValid) + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f) + else + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + ) + ) { + Column( + modifier = Modifier.padding(16.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = "Progress", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold + ) + + Icon( + imageVector = if (isValid) Icons.Default.CheckCircle else Icons.Default.Warning, + contentDescription = null, + tint = if (isValid) + MaterialTheme.colorScheme.primary + else + MaterialTheme.colorScheme.error, + modifier = Modifier.size(32.dp) + ) + } + + Spacer(modifier = Modifier.height(12.dp)) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceEvenly + ) { + StatItem( + label = "Total", + value = totalImages.toString(), + color = MaterialTheme.colorScheme.onSurface + ) + StatItem( + label = "Valid", + value = validImages.toString(), + color = if (validImages >= requiredImages) + MaterialTheme.colorScheme.primary + else + MaterialTheme.colorScheme.error + ) + StatItem( + label = "Need", + value = requiredImages.toString(), + color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.6f) + ) + } + + Spacer(modifier = Modifier.height(12.dp)) + + LinearProgressIndicator( + progress = { (validImages.toFloat() / requiredImages.toFloat()).coerceIn(0f, 1f) }, + modifier = Modifier.fillMaxWidth(), + color = if (isValid) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.error + ) + } + } +} + +@Composable +private fun StatItem(label: String, value: String, color: Color) { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Text( + text = value, + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + color = color + ) + Text( + text = label, + style = MaterialTheme.typography.bodySmall, + color = color.copy(alpha = 0.7f) + ) + } +} + +@Composable +private fun ImageResultCard( + index: Int, + result: FaceDetectionHelper.FaceDetectionResult, + onReplace: (Uri) -> Unit, + onSelectFace: (() -> Unit)? +) { + val photoPickerLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.PickVisualMedia() + ) { uri -> + uri?.let { onReplace(it) } + } + + val status = when { + result.errorMessage != null -> ImageStatus.ERROR + !result.hasFace -> ImageStatus.NO_FACE + result.faceCount > 1 -> ImageStatus.MULTIPLE_FACES + result.faceCount == 1 -> ImageStatus.VALID + else -> ImageStatus.ERROR + } + + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = when (status) { + ImageStatus.VALID -> MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiaryContainer.copy(alpha = 0.4f) + else -> MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + } + ) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + // Image Number Badge + Box( + modifier = Modifier + .size(40.dp) + .background( + color = when (status) { + ImageStatus.VALID -> MaterialTheme.colorScheme.primary + ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiary + else -> MaterialTheme.colorScheme.error + }, + shape = CircleShape + ), + contentAlignment = Alignment.Center + ) { + Text( + text = index.toString(), + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + color = Color.White + ) + } + + // Thumbnail + if (result.croppedFaceBitmap != null) { + Image( + bitmap = result.croppedFaceBitmap.asImageBitmap(), + contentDescription = "Face", + modifier = Modifier + .size(64.dp) + .clip(RoundedCornerShape(8.dp)) + .border( + BorderStroke( + 2.dp, + when (status) { + ImageStatus.VALID -> MaterialTheme.colorScheme.primary + ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiary + else -> MaterialTheme.colorScheme.error + } + ), + RoundedCornerShape(8.dp) + ), + contentScale = ContentScale.Crop + ) + } else { + AsyncImage( + model = result.uri, + contentDescription = "Original image", + modifier = Modifier + .size(64.dp) + .clip(RoundedCornerShape(8.dp)), + contentScale = ContentScale.Crop + ) + } + + // Status and Info + Column( + modifier = Modifier.weight(1f) + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + imageVector = when (status) { + ImageStatus.VALID -> Icons.Default.CheckCircle + ImageStatus.MULTIPLE_FACES -> Icons.Default.Info + else -> Icons.Default.Warning + }, + contentDescription = null, + tint = when (status) { + ImageStatus.VALID -> MaterialTheme.colorScheme.primary + ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiary + else -> MaterialTheme.colorScheme.error + }, + modifier = Modifier.size(20.dp) + ) + Spacer(modifier = Modifier.width(4.dp)) + Text( + text = when (status) { + ImageStatus.VALID -> "Face Detected" + ImageStatus.MULTIPLE_FACES -> "Multiple Faces (${result.faceCount})" + ImageStatus.NO_FACE -> "No Face Detected" + ImageStatus.ERROR -> "Error" + }, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold + ) + } + + Text( + text = result.uri.lastPathSegment ?: "Unknown", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1 + ) + } + + // Action Buttons + Column( + horizontalAlignment = Alignment.End, + verticalArrangement = Arrangement.spacedBy(4.dp) + ) { + // Select Face button (for multiple faces) + if (onSelectFace != null) { + OutlinedButton( + onClick = onSelectFace, + modifier = Modifier.height(32.dp), + contentPadding = PaddingValues(horizontal = 12.dp, vertical = 0.dp), + colors = ButtonDefaults.outlinedButtonColors( + contentColor = MaterialTheme.colorScheme.tertiary + ), + border = BorderStroke(1.dp, MaterialTheme.colorScheme.tertiary) + ) { + Icon( + Icons.Default.Face, + contentDescription = null, + modifier = Modifier.size(16.dp) + ) + Spacer(modifier = Modifier.width(4.dp)) + Text("Pick Face", style = MaterialTheme.typography.bodySmall) + } + } + + // Replace button + OutlinedButton( + onClick = { + photoPickerLauncher.launch( + PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly) + ) + }, + modifier = Modifier.height(32.dp), + contentPadding = PaddingValues(horizontal = 12.dp, vertical = 0.dp) + ) { + Icon( + Icons.Default.Refresh, + contentDescription = null, + modifier = Modifier.size(16.dp) + ) + Spacer(modifier = Modifier.width(4.dp)) + Text("Replace", style = MaterialTheme.typography.bodySmall) + } + } + } + } +} + +@Composable +private fun ValidationIssuesCard(errors: List) { + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + ) + ) { + Column( + modifier = Modifier.padding(16.dp), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + Row(verticalAlignment = Alignment.CenterVertically) { + Icon( + Icons.Default.Warning, + contentDescription = null, + tint = MaterialTheme.colorScheme.error + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = "Issues Found (${errors.size})", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.error + ) + } + + Divider(color = MaterialTheme.colorScheme.error.copy(alpha = 0.3f)) + + errors.forEach { error -> + when (error) { + is TrainingSanityChecker.ValidationError.NoFaceDetected -> { + Text( + text = "• ${error.uris.size} image(s) without detected faces - use Replace button", + style = MaterialTheme.typography.bodyMedium + ) + } + is TrainingSanityChecker.ValidationError.MultipleFacesDetected -> { + Text( + text = "• ${error.uri.lastPathSegment} has ${error.faceCount} faces - use Pick Face button", + style = MaterialTheme.typography.bodyMedium + ) + } + is TrainingSanityChecker.ValidationError.DuplicateImages -> { + Text( + text = "• ${error.groups.size} duplicate image group(s) - replace duplicates", + style = MaterialTheme.typography.bodyMedium + ) + } + is TrainingSanityChecker.ValidationError.InsufficientImages -> { + Text( + text = "• Need ${error.required} valid images, currently have ${error.available}", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Bold + ) + } + is TrainingSanityChecker.ValidationError.ImageLoadError -> { + Text( + text = "• Failed to load ${error.uri.lastPathSegment} - use Replace button", + style = MaterialTheme.typography.bodyMedium + ) + } + } + } + } + } +} + +@Composable +private fun ErrorView( + message: String, + onRetry: () -> Unit ) { Column( modifier = Modifier @@ -32,65 +632,36 @@ fun ScanResultsScreen( horizontalAlignment = Alignment.CenterHorizontally, verticalArrangement = Arrangement.Center ) { - when (state) { - is ScanningState.Processing -> { - CircularProgressIndicator() - Spacer(Modifier.height(16.dp)) - Text("Analyzing faces... ${state.current} / ${state.total}") - } - is ScanningState.Success -> { - Text( - text = "Analysis Complete!", - style = MaterialTheme.typography.headlineMedium - ) - - LazyColumn( - modifier = Modifier - .weight(1f) - .padding(vertical = 16.dp) - ) { - // FIX: Ensure 'items' is the one that takes a List, not a count - items(state.results) { result -> - Row( - modifier = Modifier - .fillMaxWidth() - .padding(8.dp), - verticalAlignment = Alignment.CenterVertically - ) { - Image( - painter = rememberAsyncImagePainter(result.uri), - contentDescription = null, - modifier = Modifier - .size(64.dp) - .clip(RoundedCornerShape(8.dp)), - contentScale = ContentScale.Crop - ) - Spacer(Modifier.width(16.dp)) - Column { - Text(if (result.faceCount > 0) "✅ Face Detected" else "❌ No Face") - if (result.hasMultipleFaces) { - Text( - text = "⚠️ Multiple faces (${result.faceCount})", - color = MaterialTheme.colorScheme.error, - style = MaterialTheme.typography.bodySmall - ) - } - } - } - } - } - - Button( - onClick = onFinish, - modifier = Modifier.fillMaxWidth(), - shape = RoundedCornerShape(8.dp) - ) { - Text("Done") - } - } - // Add fallback for other states (Idle/RequiresCrop) - // so the compiler doesn't complain about non-exhaustive 'when' - else -> { } + Icon( + imageVector = Icons.Default.Close, + contentDescription = null, + modifier = Modifier.size(64.dp), + tint = MaterialTheme.colorScheme.error + ) + Spacer(modifier = Modifier.height(16.dp)) + Text( + text = "Error", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = message, + style = MaterialTheme.typography.bodyMedium, + textAlign = TextAlign.Center + ) + Spacer(modifier = Modifier.height(24.dp)) + Button(onClick = onRetry) { + Icon(Icons.Default.Refresh, contentDescription = null) + Spacer(modifier = Modifier.width(8.dp)) + Text("Try Again") } } +} + +private enum class ImageStatus { + VALID, + MULTIPLE_FACES, + NO_FACE, + ERROR } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt index 572712e..f987c8c 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt @@ -1,120 +1,253 @@ package com.placeholder.sherpai2.ui.trainingprep -import android.content.Context -import android.graphics.Rect +import android.app.Application +import android.graphics.Bitmap import android.net.Uri -import androidx.lifecycle.ViewModel +import androidx.lifecycle.AndroidViewModel import androidx.lifecycle.viewModelScope -import com.google.mlkit.vision.common.InputImage -import com.google.mlkit.vision.face.FaceDetection -import com.google.mlkit.vision.face.FaceDetectorOptions -import com.placeholder.sherpai2.domain.repository.ImageRepository -import com.placeholder.sherpai2.domain.repository.TaggingRepository import dagger.hilt.android.lifecycle.HiltViewModel -import dagger.hilt.android.qualifiers.ApplicationContext -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow -import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch -import kotlinx.coroutines.sync.Semaphore -import kotlinx.coroutines.sync.withPermit -import kotlinx.coroutines.tasks.await -import kotlinx.coroutines.withContext import javax.inject.Inject -// 1. DEFINE THESE AT TOP LEVEL (Outside the class) so the UI can see them sealed class ScanningState { object Idle : ScanningState() - data class Processing(val current: Int, val total: Int) : ScanningState() - data class RequiresCrop(val uri: Uri, val faceBoxes: List, val remainingUris: List) : ScanningState() - data class Success(val results: List) : ScanningState() + data class Processing(val progress: Int, val total: Int) : ScanningState() + data class Success( + val sanityCheckResult: TrainingSanityChecker.SanityCheckResult + ) : ScanningState() + data class Error(val message: String) : ScanningState() } -data class ScanResult( - val uri: Uri, - val faceCount: Int, - val hasMultipleFaces: Boolean = faceCount > 1 -) - @HiltViewModel class TrainViewModel @Inject constructor( - @ApplicationContext private val context: Context, - private val imageRepository: ImageRepository, - private val taggingRepository: TaggingRepository -) : ViewModel() { + application: Application +) : AndroidViewModel(application) { + + private val sanityChecker = TrainingSanityChecker(application) + private val faceDetectionHelper = FaceDetectionHelper(application) private val _uiState = MutableStateFlow(ScanningState.Idle) val uiState: StateFlow = _uiState.asStateFlow() - private val semaphore = Semaphore(2) - private val finalResults = mutableListOf() + // Keep track of current images for replacements + private var currentImageUris: List = emptyList() - fun scanAndTagFaces(uris: List) = viewModelScope.launch { - // Goal: Deduplicate by SHA256 before starting - val allImages = imageRepository.getAllImages().first() - val uriToShaMap = allImages.associate { it.image.imageUri to it.image.sha256 } + // Keep track of manual face selections (imageUri -> selectedFaceIndex) + private val manualFaceSelections = mutableMapOf() - val uniqueUris = uris.distinctBy { uri -> - uriToShaMap[uri.toString()] ?: uri.toString() - } + data class ManualFaceSelection( + val faceIndex: Int, + val croppedFaceBitmap: Bitmap + ) - processNext(uniqueUris) + /** + * Scan and validate images for training + */ + fun scanAndTagFaces(imageUris: List) { + currentImageUris = imageUris + manualFaceSelections.clear() + performScan(imageUris) } - private suspend fun processNext(remaining: List) { - if (remaining.isEmpty()) { - _uiState.value = ScanningState.Success(finalResults.toList()) - return + /** + * Replace a single image and re-scan + */ + fun replaceImage(oldUri: Uri, newUri: Uri) { + viewModelScope.launch { + val updatedUris = currentImageUris.toMutableList() + val index = updatedUris.indexOf(oldUri) + + if (index != -1) { + updatedUris[index] = newUri + currentImageUris = updatedUris + + // Remove manual selection for old URI if any + manualFaceSelections.remove(oldUri) + + // Re-scan all images + performScan(currentImageUris) + } + } + } + + /** + * User manually selected a face from a multi-face image + */ + fun selectFaceFromImage(imageUri: Uri, faceIndex: Int, croppedFaceBitmap: Bitmap) { + manualFaceSelections[imageUri] = ManualFaceSelection(faceIndex, croppedFaceBitmap) + + // Re-process the results with the manual selection + val currentState = _uiState.value + if (currentState is ScanningState.Success) { + val updatedResult = applyManualSelections(currentState.sanityCheckResult) + _uiState.value = ScanningState.Success(updatedResult) + } + } + + /** + * Perform the actual scanning + */ + private fun performScan(imageUris: List) { + viewModelScope.launch { + try { + _uiState.value = ScanningState.Processing(0, imageUris.size) + + // Perform sanity checks + val result = sanityChecker.performSanityChecks( + imageUris = imageUris, + minImagesRequired = 10, + allowMultipleFaces = true, // Allow multiple faces - user can pick + duplicateSimilarityThreshold = 0.95 + ) + + // Apply any manual face selections + val finalResult = applyManualSelections(result) + + _uiState.value = ScanningState.Success(finalResult) + + } catch (e: Exception) { + _uiState.value = ScanningState.Error( + e.message ?: "An unknown error occurred" + ) + } + } + } + + /** + * Apply manual face selections to the results + */ + private fun applyManualSelections( + result: TrainingSanityChecker.SanityCheckResult + ): TrainingSanityChecker.SanityCheckResult { + + // If no manual selections, return original + if (manualFaceSelections.isEmpty()) { + return result } - val currentUri = remaining.first() - val nextList = remaining.drop(1) - - _uiState.value = ScanningState.Processing(finalResults.size + 1, finalResults.size + remaining.size) - - val detector = FaceDetection.getClient(faceOptions()) - try { - val image = InputImage.fromFilePath(context, currentUri) - val faces = detector.process(image).await() - - if (faces.size > 1) { - // FORCE USER TO CROP: Transition to RequiresCrop state - _uiState.value = ScanningState.RequiresCrop( - uri = currentUri, - faceBoxes = faces.map { it.boundingBox }, - remainingUris = nextList + // Update face detection results with manual selections + val updatedFaceResults = result.faceDetectionResults.map { faceResult -> + val manualSelection = manualFaceSelections[faceResult.uri] + if (manualSelection != null) { + // Replace the cropped face with the manually selected one + faceResult.copy( + croppedFaceBitmap = manualSelection.croppedFaceBitmap, + // Treat as single face since user selected one + faceCount = 1 ) } else { - val faceCount = faces.size - if (faceCount > 0) tagImage(currentUri) - - finalResults.add(ScanResult(currentUri, faceCount)) - processNext(nextList) + faceResult } - } catch (e: Exception) { - processNext(nextList) - } finally { - detector.close() } - } - private suspend fun tagImage(uri: Uri) { - val allImages = imageRepository.getAllImages().first() - val imageId = allImages.find { it.image.imageUri == uri.toString() }?.image?.imageId - if (imageId != null) { - taggingRepository.addTagToImage(imageId, "face", "ML_KIT", 1.0f) + // Update valid images list + val updatedValidImages = updatedFaceResults + .filter { it.hasFace } + .filter { it.croppedFaceBitmap != null } + .filter { it.errorMessage == null } + .filter { it.faceCount >= 1 } // Now accept if user picked a face + .map { result -> + TrainingSanityChecker.ValidTrainingImage( + uri = result.uri, + croppedFaceBitmap = result.croppedFaceBitmap!!, + faceCount = result.faceCount + ) + } + + // Recalculate validation errors + val updatedErrors = result.validationErrors.toMutableList() + + // Remove multiple face errors for images with manual selections + updatedErrors.removeAll { error -> + error is TrainingSanityChecker.ValidationError.MultipleFacesDetected && + manualFaceSelections.containsKey(error.uri) } + + // Check if we have enough valid images now + if (updatedValidImages.size < 10) { + if (updatedErrors.none { it is TrainingSanityChecker.ValidationError.InsufficientImages }) { + updatedErrors.add( + TrainingSanityChecker.ValidationError.InsufficientImages( + required = 10, + available = updatedValidImages.size + ) + ) + } + } else { + // Remove insufficient images error if we now have enough + updatedErrors.removeAll { it is TrainingSanityChecker.ValidationError.InsufficientImages } + } + + val isValid = updatedErrors.isEmpty() && updatedValidImages.size >= 10 + + return result.copy( + isValid = isValid, + faceDetectionResults = updatedFaceResults, + validationErrors = updatedErrors, + validImagesWithFaces = updatedValidImages + ) } - fun onFaceSelected(uri: Uri, box: Rect, remaining: List) = viewModelScope.launch { - tagImage(uri) - finalResults.add(ScanResult(uri, 1)) - processNext(remaining) + /** + * Get formatted error messages + */ + fun getFormattedErrors(result: TrainingSanityChecker.SanityCheckResult): List { + return sanityChecker.formatValidationErrors(result.validationErrors) } - private fun faceOptions() = FaceDetectorOptions.Builder() - .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) - .build() + /** + * Reset to idle state + */ + fun reset() { + _uiState.value = ScanningState.Idle + currentImageUris = emptyList() + manualFaceSelections.clear() + } + + override fun onCleared() { + super.onCleared() + sanityChecker.cleanup() + faceDetectionHelper.cleanup() + } +} + +// Extension function to copy FaceDetectionResult with modifications +private fun FaceDetectionHelper.FaceDetectionResult.copy( + uri: Uri = this.uri, + hasFace: Boolean = this.hasFace, + faceCount: Int = this.faceCount, + faceBounds: List = this.faceBounds, + croppedFaceBitmap: Bitmap? = this.croppedFaceBitmap, + errorMessage: String? = this.errorMessage +): FaceDetectionHelper.FaceDetectionResult { + return FaceDetectionHelper.FaceDetectionResult( + uri = uri, + hasFace = hasFace, + faceCount = faceCount, + faceBounds = faceBounds, + croppedFaceBitmap = croppedFaceBitmap, + errorMessage = errorMessage + ) +} + +// Extension function to copy SanityCheckResult with modifications +private fun TrainingSanityChecker.SanityCheckResult.copy( + isValid: Boolean = this.isValid, + faceDetectionResults: List = this.faceDetectionResults, + duplicateCheckResult: DuplicateImageDetector.DuplicateCheckResult = this.duplicateCheckResult, + validationErrors: List = this.validationErrors, + warnings: List = this.warnings, + validImagesWithFaces: List = this.validImagesWithFaces +): TrainingSanityChecker.SanityCheckResult { + return TrainingSanityChecker.SanityCheckResult( + isValid = isValid, + faceDetectionResults = faceDetectionResults, + duplicateCheckResult = duplicateCheckResult, + validationErrors = validationErrors, + warnings = warnings, + validImagesWithFaces = validImagesWithFaces + ) } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingsanitychecker.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingsanitychecker.kt new file mode 100644 index 0000000..9520b67 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingsanitychecker.kt @@ -0,0 +1,188 @@ +package com.placeholder.sherpai2.ui.trainingprep + +import android.content.Context +import android.graphics.Bitmap +import android.net.Uri + +/** + * Coordinates sanity checks for training images + */ +class TrainingSanityChecker(private val context: Context) { + + private val faceDetectionHelper = FaceDetectionHelper(context) + private val duplicateDetector = DuplicateImageDetector(context) + + data class SanityCheckResult( + val isValid: Boolean, + val faceDetectionResults: List, + val duplicateCheckResult: DuplicateImageDetector.DuplicateCheckResult, + val validationErrors: List, + val warnings: List, + val validImagesWithFaces: List + ) + + data class ValidTrainingImage( + val uri: Uri, + val croppedFaceBitmap: Bitmap, + val faceCount: Int + ) + + sealed class ValidationError { + data class NoFaceDetected(val uris: List) : ValidationError() + data class MultipleFacesDetected(val uri: Uri, val faceCount: Int) : ValidationError() + data class DuplicateImages(val groups: List) : ValidationError() + data class InsufficientImages(val required: Int, val available: Int) : ValidationError() + data class ImageLoadError(val uri: Uri, val error: String) : ValidationError() + } + + /** + * Perform comprehensive sanity checks on training images + */ + suspend fun performSanityChecks( + imageUris: List, + minImagesRequired: Int = 10, + allowMultipleFaces: Boolean = false, + duplicateSimilarityThreshold: Double = 0.95 + ): SanityCheckResult { + + val validationErrors = mutableListOf() + val warnings = mutableListOf() + + // Check minimum image count + if (imageUris.size < minImagesRequired) { + validationErrors.add( + ValidationError.InsufficientImages( + required = minImagesRequired, + available = imageUris.size + ) + ) + } + + // Step 1: Detect faces in all images + val faceDetectionResults = faceDetectionHelper.detectFacesInImages(imageUris) + + // Check for images without faces + val imagesWithoutFaces = faceDetectionResults.filter { !it.hasFace } + if (imagesWithoutFaces.isNotEmpty()) { + validationErrors.add( + ValidationError.NoFaceDetected( + uris = imagesWithoutFaces.map { it.uri } + ) + ) + } + + // Check for images with errors + faceDetectionResults.filter { it.errorMessage != null }.forEach { result -> + validationErrors.add( + ValidationError.ImageLoadError( + uri = result.uri, + error = result.errorMessage ?: "Unknown error" + ) + ) + } + + // Check for images with multiple faces + if (!allowMultipleFaces) { + faceDetectionResults.filter { it.faceCount > 1 }.forEach { result -> + validationErrors.add( + ValidationError.MultipleFacesDetected( + uri = result.uri, + faceCount = result.faceCount + ) + ) + } + } else { + faceDetectionResults.filter { it.faceCount > 1 }.forEach { result -> + warnings.add("Image ${result.uri.lastPathSegment} contains ${result.faceCount} faces. Using the largest detected face.") + } + } + + // Step 2: Check for duplicate images + val duplicateCheckResult = duplicateDetector.checkForDuplicates( + uris = imageUris, + similarityThreshold = duplicateSimilarityThreshold + ) + + if (duplicateCheckResult.hasDuplicates) { + validationErrors.add( + ValidationError.DuplicateImages( + groups = duplicateCheckResult.duplicateGroups + ) + ) + } + + // Step 3: Create list of valid training images + val validImagesWithFaces = faceDetectionResults + .filter { it.hasFace && it.croppedFaceBitmap != null } + .filter { allowMultipleFaces || it.faceCount == 1 } + .map { result -> + ValidTrainingImage( + uri = result.uri, + croppedFaceBitmap = result.croppedFaceBitmap!!, + faceCount = result.faceCount + ) + } + + // Check if we have enough valid images after all checks + if (validImagesWithFaces.size < minImagesRequired) { + val existingError = validationErrors.find { it is ValidationError.InsufficientImages } + if (existingError == null) { + validationErrors.add( + ValidationError.InsufficientImages( + required = minImagesRequired, + available = validImagesWithFaces.size + ) + ) + } + } + + val isValid = validationErrors.isEmpty() && validImagesWithFaces.size >= minImagesRequired + + return SanityCheckResult( + isValid = isValid, + faceDetectionResults = faceDetectionResults, + duplicateCheckResult = duplicateCheckResult, + validationErrors = validationErrors, + warnings = warnings, + validImagesWithFaces = validImagesWithFaces + ) + } + + /** + * Format validation errors into human-readable messages + */ + fun formatValidationErrors(errors: List): List { + return errors.map { error -> + when (error) { + is ValidationError.NoFaceDetected -> { + val count = error.uris.size + val images = error.uris.joinToString(", ") { it.lastPathSegment ?: "Unknown" } + "No face detected in $count image(s): $images" + } + is ValidationError.MultipleFacesDetected -> { + "Multiple faces (${error.faceCount}) detected in: ${error.uri.lastPathSegment}" + } + is ValidationError.DuplicateImages -> { + val count = error.groups.size + val details = error.groups.joinToString("\n") { group -> + " - ${group.images.size} duplicates: ${group.images.joinToString(", ") { it.lastPathSegment ?: "Unknown" }}" + } + "Found $count duplicate group(s):\n$details" + } + is ValidationError.InsufficientImages -> { + "Insufficient images: need ${error.required}, but only ${error.available} valid images available" + } + is ValidationError.ImageLoadError -> { + "Failed to load image ${error.uri.lastPathSegment}: ${error.error}" + } + } + } + } + + /** + * Clean up resources + */ + fun cleanup() { + faceDetectionHelper.cleanup() + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingsanityviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingsanityviewmodel.kt new file mode 100644 index 0000000..ca5ec50 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingsanityviewmodel.kt @@ -0,0 +1,78 @@ +package com.placeholder.sherpai2.ui.trainingprep + +import android.app.Application +import android.net.Uri +import androidx.lifecycle.AndroidViewModel +import androidx.lifecycle.viewModelScope +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.launch + +/** + * ViewModel for managing training image sanity checks + */ +class TrainingSanityViewModel(application: Application) : AndroidViewModel(application) { + + private val sanityChecker = TrainingSanityChecker(application) + + private val _uiState = MutableStateFlow(TrainingSanityUiState.Idle) + val uiState: StateFlow = _uiState.asStateFlow() + + sealed class TrainingSanityUiState { + object Idle : TrainingSanityUiState() + object Checking : TrainingSanityUiState() + data class Success( + val result: TrainingSanityChecker.SanityCheckResult + ) : TrainingSanityUiState() + data class Error(val message: String) : TrainingSanityUiState() + } + + /** + * Perform sanity checks on selected images + */ + fun checkImages( + imageUris: List, + minImagesRequired: Int = 10, + allowMultipleFaces: Boolean = false, + duplicateSimilarityThreshold: Double = 0.95 + ) { + viewModelScope.launch { + try { + _uiState.value = TrainingSanityUiState.Checking + + val result = sanityChecker.performSanityChecks( + imageUris = imageUris, + minImagesRequired = minImagesRequired, + allowMultipleFaces = allowMultipleFaces, + duplicateSimilarityThreshold = duplicateSimilarityThreshold + ) + + _uiState.value = TrainingSanityUiState.Success(result) + } catch (e: Exception) { + _uiState.value = TrainingSanityUiState.Error( + e.message ?: "An unknown error occurred during sanity checks" + ) + } + } + } + + /** + * Reset the UI state + */ + fun resetState() { + _uiState.value = TrainingSanityUiState.Idle + } + + /** + * Get formatted error messages from validation result + */ + fun getFormattedErrors(result: TrainingSanityChecker.SanityCheckResult): List { + return sanityChecker.formatValidationErrors(result.validationErrors) + } + + override fun onCleared() { + super.onCleared() + sanityChecker.cleanup() + } +} \ No newline at end of file