TrainScreen / FacePicker / Sanity Checking input training data (dupes, multi faces)

This commit is contained in:
genki
2026-01-02 02:20:57 -05:00
parent 22c25d5ced
commit 6734c343cc
7 changed files with 1836 additions and 148 deletions

View File

@@ -0,0 +1,159 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.InputStream
/**
 * Helper class for detecting duplicate or near-duplicate images using perceptual hashing.
 *
 * Uses a 64-bit difference hash (dHash): images whose hashes differ in only a
 * few bits are considered near-duplicates. All heavy work runs on
 * [Dispatchers.Default]; image decoding failures are skipped silently.
 */
class DuplicateImageDetector(private val context: Context) {
    /**
     * Aggregate outcome of a duplicate scan.
     *
     * @property hasDuplicates true when at least one duplicate group was found.
     * @property duplicateGroups groups of mutually similar images.
     * @property uniqueImageCount input count minus the redundant copies.
     */
    data class DuplicateCheckResult(
        val hasDuplicates: Boolean,
        val duplicateGroups: List<DuplicateGroup>,
        val uniqueImageCount: Int
    )

    /**
     * One group of near-duplicate images.
     *
     * @property similarity the lowest pairwise similarity (vs. the group leader)
     *   observed while forming the group, in [0, 1].
     */
    data class DuplicateGroup(
        val images: List<Uri>,
        val similarity: Double
    )

    // Internal pairing of an image URI with its computed 64-bit dHash.
    private data class ImageHash(
        val uri: Uri,
        val hash: Long
    )

    /**
     * Check for duplicate images in the provided list.
     *
     * @param uris candidate images; unreadable/undecodable ones are ignored.
     * @param similarityThreshold minimum similarity (0..1) to treat two images
     *   as duplicates; 0.95 allows up to 3 differing hash bits.
     */
    suspend fun checkForDuplicates(
        uris: List<Uri>,
        similarityThreshold: Double = 0.95
    ): DuplicateCheckResult = withContext(Dispatchers.Default) {
        // Fewer than two images can never contain a duplicate.
        if (uris.size < 2) {
            return@withContext DuplicateCheckResult(
                hasDuplicates = false,
                duplicateGroups = emptyList(),
                uniqueImageCount = uris.size
            )
        }
        // Compute perceptual hash for each image; free pixel memory as we go.
        val imageHashes = uris.mapNotNull { uri ->
            try {
                loadBitmap(uri)?.let { bitmap ->
                    val hash = computePerceptualHash(bitmap)
                    bitmap.recycle() // pixels no longer needed once hashed
                    ImageHash(uri, hash)
                }
            } catch (e: Exception) {
                null
            }
        }
        // Greedy grouping: each unprocessed image becomes a group leader and
        // collects all later images similar enough to it.
        val duplicateGroups = mutableListOf<DuplicateGroup>()
        val processed = mutableSetOf<Uri>()
        for (i in imageHashes.indices) {
            if (imageHashes[i].uri in processed) continue
            val currentGroup = mutableListOf(imageHashes[i].uri)
            // Track the weakest match so the reported similarity is honest.
            var groupSimilarity = 1.0
            for (j in i + 1 until imageHashes.size) {
                if (imageHashes[j].uri in processed) continue
                val similarity = calculateSimilarity(imageHashes[i].hash, imageHashes[j].hash)
                if (similarity >= similarityThreshold) {
                    currentGroup.add(imageHashes[j].uri)
                    processed.add(imageHashes[j].uri)
                    if (similarity < groupSimilarity) groupSimilarity = similarity
                }
            }
            if (currentGroup.size > 1) {
                duplicateGroups.add(
                    DuplicateGroup(
                        images = currentGroup,
                        // FIX: previously hard-coded to 1.0 regardless of the
                        // actual match strength.
                        similarity = groupSimilarity
                    )
                )
                processed.addAll(currentGroup)
            }
        }
        DuplicateCheckResult(
            hasDuplicates = duplicateGroups.isNotEmpty(),
            duplicateGroups = duplicateGroups,
            // Each group of size k contributes (k - 1) redundant images.
            uniqueImageCount = uris.size - duplicateGroups.sumOf { it.images.size - 1 }
        )
    }

    /**
     * Compute perceptual hash using difference hash (dHash) algorithm:
     * downscale to 9x8 grayscale and emit one bit per horizontal neighbor
     * comparison (64 bits total).
     */
    private fun computePerceptualHash(bitmap: Bitmap): Long {
        // 9 columns give 8 horizontal comparisons per row.
        val resized = Bitmap.createScaledBitmap(bitmap, 9, 8, false)
        var hash = 0L
        var bitIndex = 0
        for (y in 0 until 8) {
            for (x in 0 until 8) {
                val leftPixel = resized.getPixel(x, y)
                val rightPixel = resized.getPixel(x + 1, y)
                val leftGray = toGrayscale(leftPixel)
                val rightGray = toGrayscale(rightPixel)
                if (leftGray > rightGray) {
                    hash = hash or (1L shl bitIndex)
                }
                bitIndex++
            }
        }
        // createScaledBitmap may return the source instance when no scaling is
        // needed; only recycle when we actually own a new bitmap, so we never
        // destroy the caller's bitmap.
        if (resized !== bitmap) {
            resized.recycle()
        }
        return hash
    }

    /** Convert an ARGB pixel to a luma value using ITU-R BT.601 weights. */
    private fun toGrayscale(pixel: Int): Int {
        val r = (pixel shr 16) and 0xFF
        val g = (pixel shr 8) and 0xFF
        val b = pixel and 0xFF
        return (0.299 * r + 0.587 * g + 0.114 * b).toInt()
    }

    /** Similarity in [0, 1]: 1 - (Hamming distance of the 64-bit hashes) / 64. */
    private fun calculateSimilarity(hash1: Long, hash2: Long): Double {
        val xor = hash1 xor hash2
        val hammingDistance = xor.countOneBits()
        return 1.0 - (hammingDistance / 64.0)
    }

    /**
     * Load a bitmap from a content URI, or null on any failure.
     */
    private fun loadBitmap(uri: Uri): Bitmap? {
        return try {
            // use {} closes the stream even when decoding fails or returns
            // null (the previous version leaked the stream on a null decode).
            context.contentResolver.openInputStream(uri)?.use { stream ->
                BitmapFactory.decodeStream(stream)
            }
        } catch (e: Exception) {
            null
        }
    }
}

View File

@@ -0,0 +1,435 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.Canvas
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectTapGestures
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.CheckCircle
import androidx.compose.material.icons.filled.Close
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.geometry.Offset
import androidx.compose.ui.geometry.Size
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.graphics.drawscope.Stroke
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import androidx.compose.ui.window.DialogProperties
import coil.compose.AsyncImage
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
/**
 * Dialog for selecting a face from multiple detected faces.
 *
 * Shows the original image with per-face bounding boxes plus a row of cropped
 * face previews. Confirming calls [onFaceSelected] with the chosen face index
 * and its cropped bitmap.
 *
 * @param result detection output for one image: source uri plus the bounds of
 *   every detected face.
 * @param onDismiss invoked when the dialog is closed or cancelled.
 * @param onFaceSelected invoked with (faceIndex, croppedFaceBitmap) on confirm.
 */
@Composable
fun FacePickerDialog(
    result: FaceDetectionHelper.FaceDetectionResult,
    onDismiss: () -> Unit,
    onFaceSelected: (Int, Bitmap) -> Unit // faceIndex, croppedFaceBitmap
) {
    val context = LocalContext.current
    var selectedFaceIndex by remember { mutableStateOf<Int?>(null) }
    var croppedFaces by remember { mutableStateOf<List<Bitmap>>(emptyList()) }
    var isLoading by remember { mutableStateOf(true) }
    // Load the source image once off the main thread and crop every detected
    // face; an unreadable image yields an empty preview list.
    LaunchedEffect(result) {
        isLoading = true
        croppedFaces = withContext(Dispatchers.IO) {
            val bitmap = loadBitmapFromUri(context, result.uri)
            bitmap?.let { bmp ->
                result.faceBounds.map { bounds ->
                    cropFaceFromBitmap(bmp, bounds)
                }
            } ?: emptyList()
        }
        isLoading = false
        // Auto-select the first face. NOTE(review): the comment elsewhere calls
        // this the "largest" face — that depends on the detector's ordering;
        // confirm upstream before relying on it.
        if (croppedFaces.isNotEmpty()) {
            selectedFaceIndex = 0
        }
    }
    Dialog(
        onDismissRequest = onDismiss,
        // Full-width dialog: we size the card ourselves below.
        properties = DialogProperties(usePlatformDefaultWidth = false)
    ) {
        Card(
            modifier = Modifier
                .fillMaxWidth(0.95f)
                .fillMaxHeight(0.9f),
            shape = RoundedCornerShape(16.dp)
        ) {
            Column(
                modifier = Modifier
                    .fillMaxSize()
                    .padding(20.dp),
                verticalArrangement = Arrangement.spacedBy(16.dp)
            ) {
                // Header: title + face count, with a close button on the right.
                Row(
                    modifier = Modifier.fillMaxWidth(),
                    horizontalArrangement = Arrangement.SpaceBetween,
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Column {
                        Text(
                            text = "Pick a Face",
                            style = MaterialTheme.typography.headlineSmall,
                            fontWeight = FontWeight.Bold
                        )
                        Text(
                            text = "${result.faceCount} faces detected",
                            style = MaterialTheme.typography.bodyMedium,
                            color = MaterialTheme.colorScheme.onSurfaceVariant
                        )
                    }
                    IconButton(onClick = onDismiss) {
                        Icon(Icons.Default.Close, "Close")
                    }
                }
                // Instruction
                Text(
                    text = "Tap a face below to select it for training:",
                    style = MaterialTheme.typography.bodyMedium
                )
                if (isLoading) {
                    // Loading state while faces are being cropped.
                    Box(
                        modifier = Modifier
                            .fillMaxWidth()
                            .weight(1f),
                        contentAlignment = Alignment.Center
                    ) {
                        CircularProgressIndicator()
                    }
                } else {
                    // Original image with face boxes overlay
                    Card(
                        modifier = Modifier
                            .fillMaxWidth()
                            .weight(1f),
                        colors = CardDefaults.cardColors(
                            containerColor = MaterialTheme.colorScheme.surfaceVariant
                        )
                    ) {
                        Box(
                            modifier = Modifier.fillMaxSize(),
                            contentAlignment = Alignment.Center
                        ) {
                            FaceOverlayImage(
                                imageUri = result.uri,
                                faceBounds = result.faceBounds,
                                selectedFaceIndex = selectedFaceIndex,
                                onFaceClick = { index ->
                                    selectedFaceIndex = index
                                }
                            )
                        }
                    }
                    // Face previews grid
                    Text(
                        text = "Preview (tap to select):",
                        style = MaterialTheme.typography.bodyMedium,
                        fontWeight = FontWeight.SemiBold
                    )
                    Row(
                        modifier = Modifier.fillMaxWidth(),
                        horizontalArrangement = Arrangement.spacedBy(12.dp)
                    ) {
                        croppedFaces.forEachIndexed { index, faceBitmap ->
                            FacePreviewCard(
                                faceBitmap = faceBitmap,
                                index = index,
                                isSelected = selectedFaceIndex == index,
                                onClick = { selectedFaceIndex = index },
                                modifier = Modifier.weight(1f)
                            )
                        }
                    }
                }
                // Action buttons: cancel, or confirm the selected face.
                Row(
                    modifier = Modifier.fillMaxWidth(),
                    horizontalArrangement = Arrangement.spacedBy(12.dp)
                ) {
                    OutlinedButton(
                        onClick = onDismiss,
                        modifier = Modifier.weight(1f)
                    ) {
                        Text("Cancel")
                    }
                    Button(
                        onClick = {
                            selectedFaceIndex?.let { index ->
                                // Bounds check guards against a stale selection.
                                if (index < croppedFaces.size) {
                                    onFaceSelected(index, croppedFaces[index])
                                }
                            }
                        },
                        modifier = Modifier.weight(1f),
                        enabled = selectedFaceIndex != null && !isLoading
                    ) {
                        Icon(Icons.Default.CheckCircle, contentDescription = null)
                        Spacer(modifier = Modifier.width(8.dp))
                        Text("Use This Face")
                    }
                }
            }
        }
    }
}
/**
 * Image with interactive face boxes overlay.
 *
 * Renders the original image scaled to fit, draws a bounding box per detected
 * face (highlighted when selected), and hit-tests taps against the scaled face
 * rectangles so each face is individually selectable.
 *
 * FIX: the previous implementation stacked one full-size clickable Box per
 * face, so a tap anywhere always selected the LAST face; the per-face
 * scale/offset it computed was never applied, and the `imageSize` it relied on
 * came from an unrelated BoxWithConstraints. Taps are now resolved directly
 * against the same scaled rectangles that are drawn.
 *
 * @param imageUri source image to display.
 * @param faceBounds face rectangles in original-image pixel coordinates.
 * @param selectedFaceIndex index of the currently selected face, or null.
 * @param onFaceClick invoked with the index of the tapped face.
 */
@Composable
private fun FaceOverlayImage(
    imageUri: Uri,
    faceBounds: List<Rect>,
    selectedFaceIndex: Int?,
    onFaceClick: (Int) -> Unit
) {
    // Intrinsic bounds of the decoded image; needed to map face coordinates
    // from image space into canvas space.
    var imageBounds by remember { mutableStateOf(Rect()) }
    Box(
        modifier = Modifier.fillMaxSize()
    ) {
        // Original image, letterboxed to fit.
        AsyncImage(
            model = imageUri,
            contentDescription = "Original image",
            modifier = Modifier
                .fillMaxSize()
                .padding(8.dp),
            contentScale = ContentScale.Fit,
            onSuccess = { state ->
                val drawable = state.result.drawable
                imageBounds = Rect(0, 0, drawable.intrinsicWidth, drawable.intrinsicHeight)
            }
        )
        // Face boxes overlay; same padding as the image so coordinates line up.
        Canvas(
            modifier = Modifier
                .fillMaxSize()
                .padding(8.dp)
                .pointerInput(faceBounds, imageBounds) {
                    detectTapGestures { tap ->
                        if (imageBounds.width() > 0 && imageBounds.height() > 0) {
                            // Mirror the ContentScale.Fit mapping used below.
                            val scale = minOf(
                                size.width / imageBounds.width().toFloat(),
                                size.height / imageBounds.height().toFloat()
                            )
                            val offsetX = (size.width - imageBounds.width() * scale) / 2
                            val offsetY = (size.height - imageBounds.height() * scale) / 2
                            val hitIndex = faceBounds.indexOfFirst { bounds ->
                                tap.x >= bounds.left * scale + offsetX &&
                                    tap.x <= bounds.right * scale + offsetX &&
                                    tap.y >= bounds.top * scale + offsetY &&
                                    tap.y <= bounds.bottom * scale + offsetY
                            }
                            if (hitIndex >= 0) onFaceClick(hitIndex)
                        }
                    }
                }
        ) {
            if (imageBounds.width() > 0 && imageBounds.height() > 0) {
                // Calculate scale to fit image in canvas (ContentScale.Fit).
                val scaleX = size.width / imageBounds.width()
                val scaleY = size.height / imageBounds.height()
                val scale = minOf(scaleX, scaleY)
                // Calculate offset to center image
                val scaledWidth = imageBounds.width() * scale
                val scaledHeight = imageBounds.height() * scale
                val offsetX = (size.width - scaledWidth) / 2
                val offsetY = (size.height - scaledHeight) / 2
                faceBounds.forEachIndexed { index, bounds ->
                    val isSelected = selectedFaceIndex == index
                    // Scale and position the face box
                    val left = bounds.left * scale + offsetX
                    val top = bounds.top * scale + offsetY
                    val width = bounds.width() * scale
                    val height = bounds.height() * scale
                    // Draw box (green when selected, blue otherwise).
                    drawRect(
                        color = if (isSelected) Color(0xFF4CAF50) else Color(0xFF2196F3),
                        topLeft = Offset(left, top),
                        size = Size(width, height),
                        style = Stroke(width = if (isSelected) 6f else 4f)
                    )
                    // Draw semi-transparent fill for selected
                    if (isSelected) {
                        drawRect(
                            color = Color(0xFF4CAF50).copy(alpha = 0.2f),
                            topLeft = Offset(left, top),
                            size = Size(width, height)
                        )
                    }
                    // Draw face marker dot in the box corner.
                    drawCircle(
                        color = if (isSelected) Color(0xFF4CAF50) else Color(0xFF2196F3),
                        radius = 20f * scale,
                        center = Offset(left + 20f * scale, top + 20f * scale)
                    )
                }
            }
        }
    }
}
/**
 * Individual face preview card.
 *
 * Square thumbnail of one cropped face with a numbered badge; when selected it
 * gains a primary border, a tinted container, and a center checkmark.
 *
 * @param faceBitmap cropped face to display.
 * @param index zero-based face index (badge shows index + 1).
 * @param isSelected whether this face is the current selection.
 * @param onClick invoked when the card is tapped.
 */
@Composable
private fun FacePreviewCard(
    faceBitmap: Bitmap,
    index: Int,
    isSelected: Boolean,
    onClick: () -> Unit,
    modifier: Modifier = Modifier
) {
    Card(
        modifier = modifier
            .aspectRatio(1f)
            .clickable(onClick = onClick),
        colors = CardDefaults.cardColors(
            containerColor = if (isSelected)
                MaterialTheme.colorScheme.primaryContainer
            else
                MaterialTheme.colorScheme.surface
        ),
        // Thicker primary border marks the selected card.
        border = if (isSelected)
            BorderStroke(3.dp, MaterialTheme.colorScheme.primary)
        else
            BorderStroke(1.dp, MaterialTheme.colorScheme.outline)
    ) {
        Box(
            modifier = Modifier.fillMaxSize()
        ) {
            androidx.compose.foundation.Image(
                bitmap = faceBitmap.asImageBitmap(),
                contentDescription = "Face ${index + 1}",
                modifier = Modifier.fillMaxSize(),
                contentScale = ContentScale.Crop
            )
            // Selected checkmark (only show when selected)
            if (isSelected) {
                Surface(
                    modifier = Modifier
                        .align(Alignment.Center),
                    shape = CircleShape,
                    color = MaterialTheme.colorScheme.primary.copy(alpha = 0.9f)
                ) {
                    Icon(
                        Icons.Default.CheckCircle,
                        contentDescription = "Selected",
                        modifier = Modifier
                            .padding(12.dp)
                            .size(32.dp),
                        tint = MaterialTheme.colorScheme.onPrimary
                    )
                }
            }
            // Face number badge (always in top-right, small)
            Surface(
                modifier = Modifier
                    .align(Alignment.TopEnd)
                    .padding(4.dp),
                shape = CircleShape,
                color = if (isSelected)
                    MaterialTheme.colorScheme.primary
                else
                    MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.9f),
                shadowElevation = 2.dp
            ) {
                Text(
                    text = "${index + 1}",
                    modifier = Modifier.padding(6.dp),
                    style = MaterialTheme.typography.labelSmall,
                    fontWeight = FontWeight.Bold,
                    color = if (isSelected)
                        MaterialTheme.colorScheme.onPrimary
                    else
                        MaterialTheme.colorScheme.onSurfaceVariant
                )
            }
        }
    }
}
/**
 * Helper function to load a bitmap from a content URI.
 *
 * Runs on [Dispatchers.IO]. Returns null when the URI cannot be opened or the
 * stream does not decode to a bitmap.
 */
private suspend fun loadBitmapFromUri(
    context: android.content.Context,
    uri: Uri
): Bitmap? = withContext(Dispatchers.IO) {
    try {
        // use {} closes the stream even when decoding fails or returns null;
        // the previous version only closed it after a successful decode,
        // leaking the stream otherwise.
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream)
        }
    } catch (e: Exception) {
        null
    }
}
/**
 * Helper function to crop a face region out of a bitmap.
 *
 * Expands the detector's bounding box by 20% of its width on every side, then
 * clamps the region to the bitmap so [Bitmap.createBitmap] always receives a
 * valid, non-empty rectangle.
 *
 * @param bitmap source image.
 * @param faceBounds face rectangle in [bitmap] pixel coordinates; may lie
 *   partially (or even fully) outside the bitmap.
 */
private fun cropFaceFromBitmap(bitmap: Bitmap, faceBounds: Rect): Bitmap {
    // Add 20% padding around the face
    val padding = (faceBounds.width() * 0.2f).toInt()
    // Clamp so the region is always at least 1x1 and inside the bitmap; the
    // previous version crashed with a non-positive width/height when the
    // detector's bounds fell outside the image.
    val left = (faceBounds.left - padding).coerceIn(0, bitmap.width - 1)
    val top = (faceBounds.top - padding).coerceIn(0, bitmap.height - 1)
    val right = (faceBounds.right + padding).coerceIn(left + 1, bitmap.width)
    val bottom = (faceBounds.bottom + padding).coerceIn(top + 1, bitmap.height)
    return Bitmap.createBitmap(bitmap, left, top, right - left, bottom - top)
}

View File

@@ -0,0 +1,124 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import kotlinx.coroutines.tasks.await
import java.io.InputStream
/**
 * Helper class for detecting faces in images using ML Kit Face Detection.
 *
 * Wraps an ACCURATE-mode detector and exposes suspend helpers that load a
 * bitmap from a content [Uri], run detection, and pre-crop the first detected
 * face. Call [cleanup] when finished to release the underlying ML Kit client.
 */
class FaceDetectionHelper(private val context: Context) {
    private val faceDetectorOptions = FaceDetectorOptions.Builder()
        .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
        .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL)
        .setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_ALL)
        .setMinFaceSize(0.15f) // Detect faces that are at least 15% of image
        .build()

    private val detector = FaceDetection.getClient(faceDetectorOptions)

    /**
     * Outcome of running detection on a single image.
     *
     * @property uri the image that was analyzed.
     * @property hasFace true when at least one face was detected.
     * @property faceCount number of faces found.
     * @property faceBounds bounding boxes in image pixel coordinates.
     * @property croppedFaceBitmap crop of the first detected face, if any.
     * @property errorMessage non-null when loading or detection failed.
     */
    data class FaceDetectionResult(
        val uri: Uri,
        val hasFace: Boolean,
        val faceCount: Int,
        val faceBounds: List<Rect> = emptyList(),
        val croppedFaceBitmap: Bitmap? = null,
        val errorMessage: String? = null
    )

    /**
     * Detect faces in a single image.
     *
     * Never throws: load/detection failures are reported through
     * [FaceDetectionResult.errorMessage].
     */
    suspend fun detectFacesInImage(uri: Uri): FaceDetectionResult {
        return try {
            val bitmap = loadBitmap(uri)
                ?: return FaceDetectionResult(
                    uri = uri,
                    hasFace = false,
                    faceCount = 0,
                    errorMessage = "Failed to load image"
                )
            val inputImage = InputImage.fromBitmap(bitmap, 0)
            val faces = detector.process(inputImage).await()
            val croppedFace = if (faces.isNotEmpty()) {
                // Crop the first detected face with some padding
                cropFaceFromBitmap(bitmap, faces[0].boundingBox)
            } else null
            FaceDetectionResult(
                uri = uri,
                hasFace = faces.isNotEmpty(),
                faceCount = faces.size,
                faceBounds = faces.map { it.boundingBox },
                croppedFaceBitmap = croppedFace
            )
        } catch (e: Exception) {
            FaceDetectionResult(
                uri = uri,
                hasFace = false,
                faceCount = 0,
                errorMessage = e.message ?: "Unknown error"
            )
        }
    }

    /**
     * Detect faces in multiple images, sequentially, preserving input order.
     */
    suspend fun detectFacesInImages(uris: List<Uri>): List<FaceDetectionResult> {
        return uris.map { uri ->
            detectFacesInImage(uri)
        }
    }

    /**
     * Crop a face from a bitmap with 20% padding, clamped so
     * [Bitmap.createBitmap] always receives a valid, non-empty region
     * (the previous version crashed when the detector's bounds fell outside
     * the image).
     */
    private fun cropFaceFromBitmap(bitmap: Bitmap, faceBounds: Rect): Bitmap {
        // Add 20% padding around the face
        val padding = (faceBounds.width() * 0.2f).toInt()
        val left = (faceBounds.left - padding).coerceIn(0, bitmap.width - 1)
        val top = (faceBounds.top - padding).coerceIn(0, bitmap.height - 1)
        val right = (faceBounds.right + padding).coerceIn(left + 1, bitmap.width)
        val bottom = (faceBounds.bottom + padding).coerceIn(top + 1, bitmap.height)
        return Bitmap.createBitmap(bitmap, left, top, right - left, bottom - top)
    }

    /**
     * Load a bitmap from a content URI, or null on any failure.
     */
    private fun loadBitmap(uri: Uri): Bitmap? {
        return try {
            // use {} closes the stream even when decoding fails or returns
            // null (the previous version leaked the stream on a null decode).
            context.contentResolver.openInputStream(uri)?.use { stream ->
                BitmapFactory.decodeStream(stream)
            }
        } catch (e: Exception) {
            null
        }
    }

    /**
     * Clean up resources: closes the ML Kit detector client.
     */
    fun cleanup() {
        detector.close()
    }
}

View File

@@ -1,29 +1,629 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.PickVisualMediaRequest
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items // CRITICAL: This is the correct import for List items
import androidx.compose.foundation.lazy.itemsIndexed
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material3.*
import androidx.compose.runtime.Composable
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import coil.compose.rememberAsyncImagePainter
import androidx.compose.foundation.lazy.items
import androidx.hilt.navigation.compose.hiltViewModel
import coil.compose.AsyncImage
/**
 * Displays the outcome of the face detection process.
 *
 * Renders the scanning state (processing / success / error) inside a scaffold
 * and hosts the face-picker dialog used to choose one face from a multi-face
 * image.
 *
 * @param state current scanning state driving the content.
 * @param onFinish invoked when the user completes or leaves the screen.
 * @param trainViewModel view model handling image replacement and face choice.
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ScanResultsScreen(
    state: ScanningState,
    // FIX: `onFinish` was declared twice (leftover from a merge), which does
    // not compile; a single declaration is kept.
    onFinish: () -> Unit,
    trainViewModel: TrainViewModel = hiltViewModel()
) {
    // Non-null while the face-picker dialog is showing for a multi-face image.
    var showFacePickerDialog by remember { mutableStateOf<FaceDetectionHelper.FaceDetectionResult?>(null) }
    Scaffold(
        topBar = {
            TopAppBar(
                title = { Text("Training Image Analysis") },
                colors = TopAppBarDefaults.topAppBarColors(
                    containerColor = MaterialTheme.colorScheme.primaryContainer
                )
            )
        }
    ) { paddingValues ->
        Box(
            modifier = Modifier
                .fillMaxSize()
                .padding(paddingValues)
        ) {
            when (state) {
                is ScanningState.Idle -> {
                    // Should not happen: this screen is only shown after a scan starts.
                }
                is ScanningState.Processing -> {
                    ProcessingView(
                        progress = state.progress,
                        total = state.total
                    )
                }
                is ScanningState.Success -> {
                    ImprovedResultsView(
                        result = state.sanityCheckResult,
                        onContinue = onFinish,
                        onRetry = onFinish,
                        onReplaceImage = { oldUri, newUri ->
                            trainViewModel.replaceImage(oldUri, newUri)
                        },
                        onSelectFaceFromMultiple = { result ->
                            showFacePickerDialog = result
                        }
                    )
                }
                is ScanningState.Error -> {
                    ErrorView(
                        message = state.message,
                        onRetry = onFinish
                    )
                }
            }
        }
    }
    // Face Picker Dialog, rendered on top of the scaffold when requested.
    showFacePickerDialog?.let { result ->
        FacePickerDialog(
            result = result,
            onDismiss = { showFacePickerDialog = null },
            onFaceSelected = { faceIndex, croppedFaceBitmap ->
                trainViewModel.selectFaceFromImage(result.uri, faceIndex, croppedFaceBitmap)
                showFacePickerDialog = null
            }
        )
    }
}
/**
 * Full-screen progress view shown while images are being analyzed.
 *
 * @param progress number of images processed so far.
 * @param total total number of images; a determinate bar is shown when > 0.
 */
@Composable
private fun ProcessingView(progress: Int, total: Int) {
    Column(
        modifier = Modifier.fillMaxSize(),
        horizontalAlignment = Alignment.CenterHorizontally,
        verticalArrangement = Arrangement.Center
    ) {
        CircularProgressIndicator(
            modifier = Modifier.size(64.dp),
            strokeWidth = 6.dp
        )
        Spacer(modifier = Modifier.height(24.dp))
        Text(
            text = "Analyzing images...",
            style = MaterialTheme.typography.titleMedium
        )
        Spacer(modifier = Modifier.height(8.dp))
        Text(
            text = "Detecting faces and checking for duplicates",
            style = MaterialTheme.typography.bodyMedium,
            color = MaterialTheme.colorScheme.onSurfaceVariant
        )
        if (total > 0) {
            Spacer(modifier = Modifier.height(16.dp))
            // coerceIn guards against progress briefly exceeding total.
            LinearProgressIndicator(
                progress = { (progress.toFloat() / total.toFloat()).coerceIn(0f, 1f) },
                modifier = Modifier.width(200.dp)
            )
            Text(
                text = "$progress / $total",
                style = MaterialTheme.typography.bodySmall
            )
        }
    }
}
/**
 * Scrollable results list shown after a successful scan.
 *
 * Lists every analyzed image with its status and per-image actions, summarizes
 * overall validity, and offers a continue button that is enabled only when the
 * sanity check passed.
 *
 * @param result full sanity-check outcome to render.
 * @param onContinue invoked by the action button when [result] is valid.
 * @param onRetry wired to the action button for the invalid case; note the
 *   button is disabled when invalid, so this path is currently unreachable —
 *   NOTE(review): confirm whether that is intentional.
 * @param onReplaceImage invoked with (oldUri, newUri) when an image is swapped.
 * @param onSelectFaceFromMultiple opens the face picker for a multi-face image.
 */
@Composable
private fun ImprovedResultsView(
    result: TrainingSanityChecker.SanityCheckResult,
    onContinue: () -> Unit,
    onRetry: () -> Unit,
    onReplaceImage: (Uri, Uri) -> Unit,
    onSelectFaceFromMultiple: (FaceDetectionHelper.FaceDetectionResult) -> Unit
) {
    LazyColumn(
        modifier = Modifier.fillMaxSize(),
        contentPadding = PaddingValues(16.dp),
        verticalArrangement = Arrangement.spacedBy(16.dp)
    ) {
        // Welcome Header
        item {
            Card(
                modifier = Modifier.fillMaxWidth(),
                colors = CardDefaults.cardColors(
                    containerColor = MaterialTheme.colorScheme.secondaryContainer
                )
            ) {
                Column(
                    modifier = Modifier.padding(16.dp)
                ) {
                    Text(
                        text = "Analysis Complete!",
                        style = MaterialTheme.typography.headlineSmall,
                        fontWeight = FontWeight.Bold
                    )
                    Spacer(modifier = Modifier.height(4.dp))
                    Text(
                        text = "Review your images below. Tap 'Pick Face' on group photos to choose which person to train on, or 'Replace' to swap out any image.",
                        style = MaterialTheme.typography.bodyMedium,
                        color = MaterialTheme.colorScheme.onSecondaryContainer.copy(alpha = 0.8f)
                    )
                }
            }
        }
        // Progress Summary
        item {
            ProgressSummaryCard(
                totalImages = result.faceDetectionResults.size,
                validImages = result.validImagesWithFaces.size,
                // Minimum image count required for training.
                requiredImages = 10,
                isValid = result.isValid
            )
        }
        // Image List Header
        item {
            Text(
                text = "Your Images (${result.faceDetectionResults.size})",
                style = MaterialTheme.typography.titleLarge,
                fontWeight = FontWeight.Bold
            )
        }
        // Image List with Actions
        itemsIndexed(result.faceDetectionResults) { index, imageResult ->
            ImageResultCard(
                index = index + 1,
                result = imageResult,
                onReplace = { newUri ->
                    onReplaceImage(imageResult.uri, newUri)
                },
                // Face picker only offered when more than one face was found.
                onSelectFace = if (imageResult.faceCount > 1) {
                    { onSelectFaceFromMultiple(imageResult) }
                } else null
            )
        }
        // Validation Issues (if any)
        if (result.validationErrors.isNotEmpty()) {
            item {
                Spacer(modifier = Modifier.height(8.dp))
                ValidationIssuesCard(errors = result.validationErrors)
            }
        }
        // Action Button
        item {
            Spacer(modifier = Modifier.height(8.dp))
            Button(
                onClick = if (result.isValid) onContinue else onRetry,
                modifier = Modifier.fillMaxWidth(),
                enabled = result.isValid,
                colors = ButtonDefaults.buttonColors(
                    containerColor = if (result.isValid)
                        MaterialTheme.colorScheme.primary
                    else
                        MaterialTheme.colorScheme.error.copy(alpha = 0.5f)
                )
            ) {
                Icon(
                    if (result.isValid) Icons.Default.CheckCircle else Icons.Default.Warning,
                    contentDescription = null
                )
                Spacer(modifier = Modifier.width(8.dp))
                Text(
                    if (result.isValid)
                        "Continue to Training (${result.validImagesWithFaces.size} images)"
                    else
                        "Fix ${result.validationErrors.size} Issue${if (result.validationErrors.size != 1) "s" else ""} to Continue"
                )
            }
            if (!result.isValid) {
                Spacer(modifier = Modifier.height(8.dp))
                // Hint banner shown only while issues remain.
                Surface(
                    modifier = Modifier.fillMaxWidth(),
                    color = MaterialTheme.colorScheme.tertiaryContainer,
                    shape = RoundedCornerShape(8.dp)
                ) {
                    Row(
                        modifier = Modifier.padding(12.dp),
                        verticalAlignment = Alignment.CenterVertically
                    ) {
                        Icon(
                            Icons.Default.Info,
                            contentDescription = null,
                            tint = MaterialTheme.colorScheme.onTertiaryContainer,
                            modifier = Modifier.size(20.dp)
                        )
                        Spacer(modifier = Modifier.width(8.dp))
                        Text(
                            text = "Tip: Use 'Replace' to swap problematic images, or 'Pick Face' to choose from group photos",
                            style = MaterialTheme.typography.bodySmall,
                            color = MaterialTheme.colorScheme.onTertiaryContainer
                        )
                    }
                }
            }
        }
    }
}
/**
 * Summary card showing total / valid / required image counts and a progress
 * bar toward the required number of valid images.
 *
 * @param totalImages number of images analyzed.
 * @param validImages images that passed the face check.
 * @param requiredImages minimum valid images needed to proceed.
 * @param isValid whether the overall sanity check passed (drives coloring).
 */
@Composable
private fun ProgressSummaryCard(
    totalImages: Int,
    validImages: Int,
    requiredImages: Int,
    isValid: Boolean
) {
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = if (isValid)
                MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.5f)
            else
                MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
        )
    ) {
        Column(
            modifier = Modifier.padding(16.dp)
        ) {
            // Title row with a pass/fail icon on the right.
            Row(
                modifier = Modifier.fillMaxWidth(),
                horizontalArrangement = Arrangement.SpaceBetween,
                verticalAlignment = Alignment.CenterVertically
            ) {
                Text(
                    text = "Progress",
                    style = MaterialTheme.typography.titleMedium,
                    fontWeight = FontWeight.Bold
                )
                Icon(
                    imageVector = if (isValid) Icons.Default.CheckCircle else Icons.Default.Warning,
                    contentDescription = null,
                    tint = if (isValid)
                        MaterialTheme.colorScheme.primary
                    else
                        MaterialTheme.colorScheme.error,
                    modifier = Modifier.size(32.dp)
                )
            }
            Spacer(modifier = Modifier.height(12.dp))
            // Three stat columns: total, valid (colored by sufficiency), needed.
            Row(
                modifier = Modifier.fillMaxWidth(),
                horizontalArrangement = Arrangement.SpaceEvenly
            ) {
                StatItem(
                    label = "Total",
                    value = totalImages.toString(),
                    color = MaterialTheme.colorScheme.onSurface
                )
                StatItem(
                    label = "Valid",
                    value = validImages.toString(),
                    color = if (validImages >= requiredImages)
                        MaterialTheme.colorScheme.primary
                    else
                        MaterialTheme.colorScheme.error
                )
                StatItem(
                    label = "Need",
                    value = requiredImages.toString(),
                    color = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.6f)
                )
            }
            Spacer(modifier = Modifier.height(12.dp))
            // Progress toward the required count, capped at 100%.
            LinearProgressIndicator(
                progress = { (validImages.toFloat() / requiredImages.toFloat()).coerceIn(0f, 1f) },
                modifier = Modifier.fillMaxWidth(),
                color = if (isValid) MaterialTheme.colorScheme.primary else MaterialTheme.colorScheme.error
            )
        }
    }
}
/**
 * Single labeled statistic: a large colored value over a dimmed label.
 *
 * @param label caption shown under the value.
 * @param value statistic text, rendered prominently.
 * @param color color for the value; label uses the same color at 70% alpha.
 */
@Composable
private fun StatItem(label: String, value: String, color: Color) {
    Column(horizontalAlignment = Alignment.CenterHorizontally) {
        Text(
            text = value,
            style = MaterialTheme.typography.headlineMedium,
            fontWeight = FontWeight.Bold,
            color = color
        )
        Text(
            text = label,
            style = MaterialTheme.typography.bodySmall,
            color = color.copy(alpha = 0.7f)
        )
    }
}
/**
 * One row in the results list: numbered badge, thumbnail, status text, and
 * per-image actions ("Pick Face" for multi-face images, "Replace" always).
 *
 * @param index 1-based position shown in the badge.
 * @param result detection outcome for this image.
 * @param onReplace invoked with the newly picked URI from the photo picker.
 * @param onSelectFace when non-null, shows the "Pick Face" button that opens
 *   the face-picker dialog; passed as null for single-face images.
 */
@Composable
private fun ImageResultCard(
    index: Int,
    result: FaceDetectionHelper.FaceDetectionResult,
    onReplace: (Uri) -> Unit,
    onSelectFace: (() -> Unit)?
) {
    // System photo picker used by the Replace button.
    val photoPickerLauncher = rememberLauncherForActivityResult(
        contract = ActivityResultContracts.PickVisualMedia()
    ) { uri ->
        uri?.let { onReplace(it) }
    }
    // Derive a display status from the detection result; order matters —
    // an explicit error wins over face-count checks.
    val status = when {
        result.errorMessage != null -> ImageStatus.ERROR
        !result.hasFace -> ImageStatus.NO_FACE
        result.faceCount > 1 -> ImageStatus.MULTIPLE_FACES
        result.faceCount == 1 -> ImageStatus.VALID
        else -> ImageStatus.ERROR
    }
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = when (status) {
                ImageStatus.VALID -> MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f)
                ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiaryContainer.copy(alpha = 0.4f)
                else -> MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
            }
        )
    ) {
        Row(
            modifier = Modifier
                .fillMaxWidth()
                .padding(12.dp),
            verticalAlignment = Alignment.CenterVertically,
            horizontalArrangement = Arrangement.spacedBy(12.dp)
        ) {
            // Image Number Badge
            Box(
                modifier = Modifier
                    .size(40.dp)
                    .background(
                        color = when (status) {
                            ImageStatus.VALID -> MaterialTheme.colorScheme.primary
                            ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiary
                            else -> MaterialTheme.colorScheme.error
                        },
                        shape = CircleShape
                    ),
                contentAlignment = Alignment.Center
            ) {
                Text(
                    text = index.toString(),
                    style = MaterialTheme.typography.titleMedium,
                    fontWeight = FontWeight.Bold,
                    color = Color.White
                )
            }
            // Thumbnail: prefer the pre-cropped face; fall back to the
            // original image when no face crop is available.
            if (result.croppedFaceBitmap != null) {
                Image(
                    bitmap = result.croppedFaceBitmap.asImageBitmap(),
                    contentDescription = "Face",
                    modifier = Modifier
                        .size(64.dp)
                        .clip(RoundedCornerShape(8.dp))
                        .border(
                            BorderStroke(
                                2.dp,
                                when (status) {
                                    ImageStatus.VALID -> MaterialTheme.colorScheme.primary
                                    ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiary
                                    else -> MaterialTheme.colorScheme.error
                                }
                            ),
                            RoundedCornerShape(8.dp)
                        ),
                    contentScale = ContentScale.Crop
                )
            } else {
                AsyncImage(
                    model = result.uri,
                    contentDescription = "Original image",
                    modifier = Modifier
                        .size(64.dp)
                        .clip(RoundedCornerShape(8.dp)),
                    contentScale = ContentScale.Crop
                )
            }
            // Status and Info
            Column(
                modifier = Modifier.weight(1f)
            ) {
                Row(verticalAlignment = Alignment.CenterVertically) {
                    Icon(
                        imageVector = when (status) {
                            ImageStatus.VALID -> Icons.Default.CheckCircle
                            ImageStatus.MULTIPLE_FACES -> Icons.Default.Info
                            else -> Icons.Default.Warning
                        },
                        contentDescription = null,
                        tint = when (status) {
                            ImageStatus.VALID -> MaterialTheme.colorScheme.primary
                            ImageStatus.MULTIPLE_FACES -> MaterialTheme.colorScheme.tertiary
                            else -> MaterialTheme.colorScheme.error
                        },
                        modifier = Modifier.size(20.dp)
                    )
                    Spacer(modifier = Modifier.width(4.dp))
                    Text(
                        text = when (status) {
                            ImageStatus.VALID -> "Face Detected"
                            ImageStatus.MULTIPLE_FACES -> "Multiple Faces (${result.faceCount})"
                            ImageStatus.NO_FACE -> "No Face Detected"
                            ImageStatus.ERROR -> "Error"
                        },
                        style = MaterialTheme.typography.bodyMedium,
                        fontWeight = FontWeight.SemiBold
                    )
                }
                Text(
                    text = result.uri.lastPathSegment ?: "Unknown",
                    style = MaterialTheme.typography.bodySmall,
                    color = MaterialTheme.colorScheme.onSurfaceVariant,
                    maxLines = 1
                )
            }
            // Action Buttons
            Column(
                horizontalAlignment = Alignment.End,
                verticalArrangement = Arrangement.spacedBy(4.dp)
            ) {
                // Select Face button (for multiple faces)
                if (onSelectFace != null) {
                    OutlinedButton(
                        onClick = onSelectFace,
                        modifier = Modifier.height(32.dp),
                        contentPadding = PaddingValues(horizontal = 12.dp, vertical = 0.dp),
                        colors = ButtonDefaults.outlinedButtonColors(
                            contentColor = MaterialTheme.colorScheme.tertiary
                        ),
                        border = BorderStroke(1.dp, MaterialTheme.colorScheme.tertiary)
                    ) {
                        Icon(
                            Icons.Default.Face,
                            contentDescription = null,
                            modifier = Modifier.size(16.dp)
                        )
                        Spacer(modifier = Modifier.width(4.dp))
                        Text("Pick Face", style = MaterialTheme.typography.bodySmall)
                    }
                }
                // Replace button
                OutlinedButton(
                    onClick = {
                        photoPickerLauncher.launch(
                            PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly)
                        )
                    },
                    modifier = Modifier.height(32.dp),
                    contentPadding = PaddingValues(horizontal = 12.dp, vertical = 0.dp)
                ) {
                    Icon(
                        Icons.Default.Refresh,
                        contentDescription = null,
                        modifier = Modifier.size(16.dp)
                    )
                    Spacer(modifier = Modifier.width(4.dp))
                    Text("Replace", style = MaterialTheme.typography.bodySmall)
                }
            }
        }
    }
}
/**
 * Card listing every validation error with a human-readable hint.
 *
 * NOTE(review): message formatting is inconsistent — InsufficientImages and
 * ImageLoadError are prefixed with a bullet ("•"), the other variants are not;
 * consider unifying.
 *
 * @param errors non-empty list of errors from the sanity checker.
 */
@Composable
private fun ValidationIssuesCard(errors: List<TrainingSanityChecker.ValidationError>) {
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
        )
    ) {
        Column(
            modifier = Modifier.padding(16.dp),
            verticalArrangement = Arrangement.spacedBy(8.dp)
        ) {
            // Header row with warning icon and issue count.
            Row(verticalAlignment = Alignment.CenterVertically) {
                Icon(
                    Icons.Default.Warning,
                    contentDescription = null,
                    tint = MaterialTheme.colorScheme.error
                )
                Spacer(modifier = Modifier.width(8.dp))
                Text(
                    text = "Issues Found (${errors.size})",
                    style = MaterialTheme.typography.titleMedium,
                    fontWeight = FontWeight.Bold,
                    color = MaterialTheme.colorScheme.error
                )
            }
            Divider(color = MaterialTheme.colorScheme.error.copy(alpha = 0.3f))
            // One line per error variant, each pointing at the relevant action.
            errors.forEach { error ->
                when (error) {
                    is TrainingSanityChecker.ValidationError.NoFaceDetected -> {
                        Text(
                            text = "${error.uris.size} image(s) without detected faces - use Replace button",
                            style = MaterialTheme.typography.bodyMedium
                        )
                    }
                    is TrainingSanityChecker.ValidationError.MultipleFacesDetected -> {
                        Text(
                            text = "${error.uri.lastPathSegment} has ${error.faceCount} faces - use Pick Face button",
                            style = MaterialTheme.typography.bodyMedium
                        )
                    }
                    is TrainingSanityChecker.ValidationError.DuplicateImages -> {
                        Text(
                            text = "${error.groups.size} duplicate image group(s) - replace duplicates",
                            style = MaterialTheme.typography.bodyMedium
                        )
                    }
                    is TrainingSanityChecker.ValidationError.InsufficientImages -> {
                        Text(
                            text = "• Need ${error.required} valid images, currently have ${error.available}",
                            style = MaterialTheme.typography.bodyMedium,
                            fontWeight = FontWeight.Bold
                        )
                    }
                    is TrainingSanityChecker.ValidationError.ImageLoadError -> {
                        Text(
                            text = "• Failed to load ${error.uri.lastPathSegment} - use Replace button",
                            style = MaterialTheme.typography.bodyMedium
                        )
                    }
                }
            }
        }
    }
}
/**
 * Full-screen error state with a retry action.
 *
 * FIX: the original body was corrupted by a bad merge — a deleted
 * `ScanResultsScreen` state-`when` block and a diff hunk header were
 * interleaved inside this function, making it non-compiling. This is the
 * reconstructed clean version containing only the error UI.
 *
 * @param message human-readable error description shown below the title
 * @param onRetry invoked when the user taps "Try Again"
 */
@Composable
private fun ErrorView(
    message: String,
    onRetry: () -> Unit
) {
    Column(
        modifier = Modifier
            // NOTE(review): the original modifier chain was lost in the merge;
            // restored to a conventional full-screen centered layout - confirm.
            .fillMaxSize()
            .padding(16.dp),
        horizontalAlignment = Alignment.CenterHorizontally,
        verticalArrangement = Arrangement.Center
    ) {
        Icon(
            imageVector = Icons.Default.Close,
            contentDescription = null,
            modifier = Modifier.size(64.dp),
            tint = MaterialTheme.colorScheme.error
        )
        Spacer(modifier = Modifier.height(16.dp))
        Text(
            text = "Error",
            style = MaterialTheme.typography.titleLarge,
            fontWeight = FontWeight.Bold
        )
        Spacer(modifier = Modifier.height(8.dp))
        Text(
            text = message,
            style = MaterialTheme.typography.bodyMedium,
            textAlign = TextAlign.Center
        )
        Spacer(modifier = Modifier.height(24.dp))
        Button(onClick = onRetry) {
            Icon(Icons.Default.Refresh, contentDescription = null)
            Spacer(modifier = Modifier.width(8.dp))
            Text("Try Again")
        }
    }
}
// Per-image validation status; names mirror the ValidationError cases above
// (MULTIPLE_FACES -> "Pick Face", NO_FACE/ERROR -> "Replace").
private enum class ImageStatus {
    VALID,           // a usable face was detected
    MULTIPLE_FACES,  // more than one face - user picks one via Pick Face
    NO_FACE,         // no face detected - image must be replaced
    ERROR            // image failed to load/decode
}

View File

@@ -1,120 +1,253 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.content.Context
import android.graphics.Rect
import android.app.Application
import android.graphics.Bitmap
import android.net.Uri
import androidx.lifecycle.ViewModel
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.domain.repository.ImageRepository
import com.placeholder.sherpai2.domain.repository.TaggingRepository
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.tasks.await
import kotlinx.coroutines.withContext
import javax.inject.Inject
// Defined at top level (outside the ViewModel) so Compose screens can observe them.
//
// FIX: merge residue had declared `Processing` and `Success` TWICE with
// conflicting signatures (old: current/results-list, new: progress/sanity
// report) — a compile error. Deduplicated, keeping the new shapes.
sealed class ScanningState {
    object Idle : ScanningState()

    /** Scan in progress: [progress] of [total] images analysed so far. */
    data class Processing(val progress: Int, val total: Int) : ScanningState()

    /**
     * Legacy manual-crop request from the pre-sanity-check flow.
     * NOTE(review): the new scan path never emits this; kept so any remaining
     * callers that match on it still compile - confirm and delete if unused.
     */
    data class RequiresCrop(val uri: Uri, val faceBoxes: List<Rect>, val remainingUris: List<Uri>) : ScanningState()

    /** Scan finished; carries the full sanity-check report for the selection. */
    data class Success(
        val sanityCheckResult: TrainingSanityChecker.SanityCheckResult
    ) : ScanningState()

    /** Scan failed; [message] is user-presentable. */
    data class Error(val message: String) : ScanningState()
}

/**
 * Lightweight per-image scan summary (legacy result shape).
 * NOTE(review): superseded by SanityCheckResult - confirm and delete if unused.
 */
data class ScanResult(
    val uri: Uri,
    val faceCount: Int,
    val hasMultipleFaces: Boolean = faceCount > 1
)
/**
 * ViewModel driving the training-image preparation flow: runs sanity checks
 * (face detection, duplicate detection, minimum count) over the selected
 * images, and supports per-image replacement and manual face selection for
 * multi-face images.
 *
 * FIX: the original body was merge residue — the old ML-Kit tagging flow
 * (`processNext`, `tagImage`, `onFaceSelected`, semaphore, repositories) was
 * interleaved with the new sanity-check flow, duplicating the constructor and
 * leaving the class non-compiling. This is the reconstructed new flow; the
 * magic minimum-image count is also hoisted into MIN_VALID_IMAGES.
 */
@HiltViewModel
class TrainViewModel @Inject constructor(
    application: Application
) : AndroidViewModel(application) {

    private val sanityChecker = TrainingSanityChecker(application)
    private val faceDetectionHelper = FaceDetectionHelper(application)

    private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
    val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()

    // Current selection, kept so a single image can be swapped and re-scanned.
    private var currentImageUris: List<Uri> = emptyList()

    // Manual face picks for multi-face images (imageUri -> selection).
    private val manualFaceSelections = mutableMapOf<Uri, ManualFaceSelection>()

    /** A user's manual pick of one face out of a multi-face image. */
    data class ManualFaceSelection(
        val faceIndex: Int,
        val croppedFaceBitmap: Bitmap
    )

    /**
     * Scan and validate [imageUris] for training; resets any prior manual
     * selections before scanning.
     */
    fun scanAndTagFaces(imageUris: List<Uri>) {
        currentImageUris = imageUris
        manualFaceSelections.clear()
        performScan(imageUris)
    }

    /**
     * Replace [oldUri] with [newUri] in the current selection and re-scan.
     * No-op if [oldUri] is not part of the current selection.
     */
    fun replaceImage(oldUri: Uri, newUri: Uri) {
        viewModelScope.launch {
            val updatedUris = currentImageUris.toMutableList()
            val index = updatedUris.indexOf(oldUri)
            if (index != -1) {
                updatedUris[index] = newUri
                currentImageUris = updatedUris
                // Any manual pick made on the old image no longer applies.
                manualFaceSelections.remove(oldUri)
                performScan(currentImageUris)
            }
        }
    }

    /**
     * Record that the user manually selected face [faceIndex] (already cropped
     * to [croppedFaceBitmap]) from a multi-face image, and re-derive the
     * current success state without a full re-scan.
     */
    fun selectFaceFromImage(imageUri: Uri, faceIndex: Int, croppedFaceBitmap: Bitmap) {
        manualFaceSelections[imageUri] = ManualFaceSelection(faceIndex, croppedFaceBitmap)
        val currentState = _uiState.value
        if (currentState is ScanningState.Success) {
            _uiState.value = ScanningState.Success(
                applyManualSelections(currentState.sanityCheckResult)
            )
        }
    }

    /** Run the sanity checks and publish the (selection-adjusted) result. */
    private fun performScan(imageUris: List<Uri>) {
        viewModelScope.launch {
            try {
                _uiState.value = ScanningState.Processing(0, imageUris.size)
                val result = sanityChecker.performSanityChecks(
                    imageUris = imageUris,
                    minImagesRequired = MIN_VALID_IMAGES,
                    allowMultipleFaces = true, // user can pick a face later
                    duplicateSimilarityThreshold = 0.95
                )
                _uiState.value = ScanningState.Success(applyManualSelections(result))
            } catch (e: kotlinx.coroutines.CancellationException) {
                // Never swallow coroutine cancellation.
                throw e
            } catch (e: Exception) {
                _uiState.value = ScanningState.Error(
                    e.message ?: "An unknown error occurred"
                )
            }
        }
    }

    /**
     * Overlay the user's manual face selections onto [result]: selected images
     * are treated as single-face, their MultipleFacesDetected errors are
     * dropped, and the valid-image list / InsufficientImages error / overall
     * validity are recomputed.
     */
    private fun applyManualSelections(
        result: TrainingSanityChecker.SanityCheckResult
    ): TrainingSanityChecker.SanityCheckResult {
        if (manualFaceSelections.isEmpty()) {
            return result
        }
        // Swap in the manually cropped face and treat the image as single-face.
        val updatedFaceResults = result.faceDetectionResults.map { faceResult ->
            val manualSelection = manualFaceSelections[faceResult.uri]
            if (manualSelection != null) {
                faceResult.copy(
                    croppedFaceBitmap = manualSelection.croppedFaceBitmap,
                    faceCount = 1
                )
            } else {
                faceResult
            }
        }
        // Rebuild the valid-image list; any detected face now counts.
        val updatedValidImages = updatedFaceResults
            .filter { it.hasFace }
            .filter { it.croppedFaceBitmap != null }
            .filter { it.errorMessage == null }
            .filter { it.faceCount >= 1 }
            .map { faceResult ->
                TrainingSanityChecker.ValidTrainingImage(
                    uri = faceResult.uri,
                    croppedFaceBitmap = faceResult.croppedFaceBitmap!!,
                    faceCount = faceResult.faceCount
                )
            }
        val updatedErrors = result.validationErrors.toMutableList()
        // A manual pick resolves that image's multiple-faces error.
        updatedErrors.removeAll { error ->
            error is TrainingSanityChecker.ValidationError.MultipleFacesDetected &&
                manualFaceSelections.containsKey(error.uri)
        }
        // Keep the InsufficientImages error in sync with the new valid count.
        if (updatedValidImages.size < MIN_VALID_IMAGES) {
            if (updatedErrors.none { it is TrainingSanityChecker.ValidationError.InsufficientImages }) {
                updatedErrors.add(
                    TrainingSanityChecker.ValidationError.InsufficientImages(
                        required = MIN_VALID_IMAGES,
                        available = updatedValidImages.size
                    )
                )
            }
        } else {
            updatedErrors.removeAll { it is TrainingSanityChecker.ValidationError.InsufficientImages }
        }
        val isValid = updatedErrors.isEmpty() && updatedValidImages.size >= MIN_VALID_IMAGES
        return result.copy(
            isValid = isValid,
            faceDetectionResults = updatedFaceResults,
            validationErrors = updatedErrors,
            validImagesWithFaces = updatedValidImages
        )
    }

    /** Human-readable messages for the errors in [result]. */
    fun getFormattedErrors(result: TrainingSanityChecker.SanityCheckResult): List<String> {
        return sanityChecker.formatValidationErrors(result.validationErrors)
    }

    /** Return to the idle state and drop all selection bookkeeping. */
    fun reset() {
        _uiState.value = ScanningState.Idle
        currentImageUris = emptyList()
        manualFaceSelections.clear()
    }

    override fun onCleared() {
        super.onCleared()
        sanityChecker.cleanup()
        faceDetectionHelper.cleanup()
    }

    companion object {
        // Minimum number of valid face crops required before training may start.
        private const val MIN_VALID_IMAGES = 10
    }
}
// Extension function to copy FaceDetectionResult with modifications.
// Hand-rolled because FaceDetectionHelper.FaceDetectionResult is declared in
// another file. NOTE(review): if that type is (or becomes) a data class, its
// member `copy` shadows this private extension - verify and delete if redundant.
private fun FaceDetectionHelper.FaceDetectionResult.copy(
    uri: Uri = this.uri,
    hasFace: Boolean = this.hasFace,
    faceCount: Int = this.faceCount,
    faceBounds: List<android.graphics.Rect> = this.faceBounds,
    croppedFaceBitmap: Bitmap? = this.croppedFaceBitmap,
    errorMessage: String? = this.errorMessage
): FaceDetectionHelper.FaceDetectionResult {
    return FaceDetectionHelper.FaceDetectionResult(
        uri = uri,
        hasFace = hasFace,
        faceCount = faceCount,
        faceBounds = faceBounds,
        croppedFaceBitmap = croppedFaceBitmap,
        errorMessage = errorMessage
    )
}
// Extension function to copy SanityCheckResult with modifications.
// NOTE(review): SanityCheckResult is declared as a `data class` in
// TrainingSanityChecker.kt, so its generated member `copy` shadows this
// private extension with the same signature - this function is likely never
// invoked; verify and delete.
private fun TrainingSanityChecker.SanityCheckResult.copy(
    isValid: Boolean = this.isValid,
    faceDetectionResults: List<FaceDetectionHelper.FaceDetectionResult> = this.faceDetectionResults,
    duplicateCheckResult: DuplicateImageDetector.DuplicateCheckResult = this.duplicateCheckResult,
    validationErrors: List<TrainingSanityChecker.ValidationError> = this.validationErrors,
    warnings: List<String> = this.warnings,
    validImagesWithFaces: List<TrainingSanityChecker.ValidTrainingImage> = this.validImagesWithFaces
): TrainingSanityChecker.SanityCheckResult {
    return TrainingSanityChecker.SanityCheckResult(
        isValid = isValid,
        faceDetectionResults = faceDetectionResults,
        duplicateCheckResult = duplicateCheckResult,
        validationErrors = validationErrors,
        warnings = warnings,
        validImagesWithFaces = validImagesWithFaces
    )
}

View File

@@ -0,0 +1,188 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.content.Context
import android.graphics.Bitmap
import android.net.Uri
/**
 * Coordinates sanity checks for training images: per-image face detection,
 * perceptual-hash duplicate detection, and minimum-count validation,
 * aggregated into a single [SanityCheckResult] report.
 */
class TrainingSanityChecker(private val context: Context) {
    private val faceDetectionHelper = FaceDetectionHelper(context)
    private val duplicateDetector = DuplicateImageDetector(context)

    /**
     * Aggregated outcome of [performSanityChecks].
     *
     * @property isValid true iff there are no validation errors AND enough
     *   valid face crops were produced
     * @property validImagesWithFaces the subset of images usable for training
     */
    data class SanityCheckResult(
        val isValid: Boolean,
        val faceDetectionResults: List<FaceDetectionHelper.FaceDetectionResult>,
        val duplicateCheckResult: DuplicateImageDetector.DuplicateCheckResult,
        val validationErrors: List<ValidationError>,
        val warnings: List<String>,
        val validImagesWithFaces: List<ValidTrainingImage>
    )

    /** An image that passed all checks, together with its cropped face. */
    data class ValidTrainingImage(
        val uri: Uri,
        val croppedFaceBitmap: Bitmap,
        val faceCount: Int
    )

    /** Closed set of validation failures; UIs match on these exhaustively. */
    sealed class ValidationError {
        data class NoFaceDetected(val uris: List<Uri>) : ValidationError()
        data class MultipleFacesDetected(val uri: Uri, val faceCount: Int) : ValidationError()
        data class DuplicateImages(val groups: List<DuplicateImageDetector.DuplicateGroup>) : ValidationError()
        data class InsufficientImages(val required: Int, val available: Int) : ValidationError()
        data class ImageLoadError(val uri: Uri, val error: String) : ValidationError()
    }

    /**
     * Perform comprehensive sanity checks on training images.
     *
     * @param imageUris candidate training images
     * @param minImagesRequired minimum number of valid face crops needed
     * @param allowMultipleFaces when true, multi-face images produce warnings
     *   instead of errors (the caller lets the user pick a face)
     * @param duplicateSimilarityThreshold perceptual-hash similarity in [0,1]
     *   above which two images count as duplicates
     */
    suspend fun performSanityChecks(
        imageUris: List<Uri>,
        minImagesRequired: Int = 10,
        allowMultipleFaces: Boolean = false,
        duplicateSimilarityThreshold: Double = 0.95
    ): SanityCheckResult {
        val validationErrors = mutableListOf<ValidationError>()
        val warnings = mutableListOf<String>()
        // Early check on the raw selection size (re-checked after filtering).
        if (imageUris.size < minImagesRequired) {
            validationErrors.add(
                ValidationError.InsufficientImages(
                    required = minImagesRequired,
                    available = imageUris.size
                )
            )
        }
        // Step 1: Detect faces in all images
        val faceDetectionResults = faceDetectionHelper.detectFacesInImages(imageUris)
        // Check for images without faces.
        // NOTE(review): this filter also matches images whose load failed
        // (errorMessage != null), which are reported again as ImageLoadError
        // below - confirm the double report is intended by the UI.
        val imagesWithoutFaces = faceDetectionResults.filter { !it.hasFace }
        if (imagesWithoutFaces.isNotEmpty()) {
            validationErrors.add(
                ValidationError.NoFaceDetected(
                    uris = imagesWithoutFaces.map { it.uri }
                )
            )
        }
        // Check for images with load/decode errors
        faceDetectionResults.filter { it.errorMessage != null }.forEach { result ->
            validationErrors.add(
                ValidationError.ImageLoadError(
                    uri = result.uri,
                    error = result.errorMessage ?: "Unknown error"
                )
            )
        }
        // Multi-face images: hard error, or warning when the caller allows them.
        if (!allowMultipleFaces) {
            faceDetectionResults.filter { it.faceCount > 1 }.forEach { result ->
                validationErrors.add(
                    ValidationError.MultipleFacesDetected(
                        uri = result.uri,
                        faceCount = result.faceCount
                    )
                )
            }
        } else {
            faceDetectionResults.filter { it.faceCount > 1 }.forEach { result ->
                warnings.add("Image ${result.uri.lastPathSegment} contains ${result.faceCount} faces. Using the largest detected face.")
            }
        }
        // Step 2: Check for duplicate images
        val duplicateCheckResult = duplicateDetector.checkForDuplicates(
            uris = imageUris,
            similarityThreshold = duplicateSimilarityThreshold
        )
        if (duplicateCheckResult.hasDuplicates) {
            validationErrors.add(
                ValidationError.DuplicateImages(
                    groups = duplicateCheckResult.duplicateGroups
                )
            )
        }
        // Step 3: Create list of valid training images
        val validImagesWithFaces = faceDetectionResults
            .filter { it.hasFace && it.croppedFaceBitmap != null }
            .filter { allowMultipleFaces || it.faceCount == 1 }
            .map { result ->
                ValidTrainingImage(
                    uri = result.uri,
                    croppedFaceBitmap = result.croppedFaceBitmap!!,
                    faceCount = result.faceCount
                )
            }
        // Re-check the minimum against the filtered (usable) count.
        // FIX: idiomatic `none` instead of `find { ... } == null`, consistent
        // with the equivalent check in TrainViewModel.applyManualSelections.
        if (validImagesWithFaces.size < minImagesRequired) {
            if (validationErrors.none { it is ValidationError.InsufficientImages }) {
                validationErrors.add(
                    ValidationError.InsufficientImages(
                        required = minImagesRequired,
                        available = validImagesWithFaces.size
                    )
                )
            }
        }
        val isValid = validationErrors.isEmpty() && validImagesWithFaces.size >= minImagesRequired
        return SanityCheckResult(
            isValid = isValid,
            faceDetectionResults = faceDetectionResults,
            duplicateCheckResult = duplicateCheckResult,
            validationErrors = validationErrors,
            warnings = warnings,
            validImagesWithFaces = validImagesWithFaces
        )
    }

    /**
     * Format validation errors into human-readable messages, one per error.
     */
    fun formatValidationErrors(errors: List<ValidationError>): List<String> {
        return errors.map { error ->
            when (error) {
                is ValidationError.NoFaceDetected -> {
                    val count = error.uris.size
                    val images = error.uris.joinToString(", ") { it.lastPathSegment ?: "Unknown" }
                    "No face detected in $count image(s): $images"
                }
                is ValidationError.MultipleFacesDetected -> {
                    "Multiple faces (${error.faceCount}) detected in: ${error.uri.lastPathSegment}"
                }
                is ValidationError.DuplicateImages -> {
                    val count = error.groups.size
                    val details = error.groups.joinToString("\n") { group ->
                        "  - ${group.images.size} duplicates: ${group.images.joinToString(", ") { it.lastPathSegment ?: "Unknown" }}"
                    }
                    "Found $count duplicate group(s):\n$details"
                }
                is ValidationError.InsufficientImages -> {
                    "Insufficient images: need ${error.required}, but only ${error.available} valid images available"
                }
                is ValidationError.ImageLoadError -> {
                    "Failed to load image ${error.uri.lastPathSegment}: ${error.error}"
                }
            }
        }
    }

    /**
     * Clean up resources held by the face detector.
     */
    fun cleanup() {
        faceDetectionHelper.cleanup()
    }
}

View File

@@ -0,0 +1,78 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.app.Application
import android.net.Uri
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
/**
 * ViewModel for managing training image sanity checks: exposes a single
 * [uiState] flow driven by [checkImages].
 */
class TrainingSanityViewModel(application: Application) : AndroidViewModel(application) {
    private val sanityChecker = TrainingSanityChecker(application)

    private val _uiState = MutableStateFlow<TrainingSanityUiState>(TrainingSanityUiState.Idle)
    val uiState: StateFlow<TrainingSanityUiState> = _uiState.asStateFlow()

    /** UI states for the sanity-check flow. */
    sealed class TrainingSanityUiState {
        object Idle : TrainingSanityUiState()
        object Checking : TrainingSanityUiState()
        data class Success(
            val result: TrainingSanityChecker.SanityCheckResult
        ) : TrainingSanityUiState()
        data class Error(val message: String) : TrainingSanityUiState()
    }

    /**
     * Perform sanity checks on selected images; transitions the UI state
     * Idle/any -> Checking -> Success|Error.
     *
     * @param imageUris candidate training images
     * @param minImagesRequired minimum number of valid face crops needed
     * @param allowMultipleFaces whether multi-face images are warnings
     *   rather than errors
     * @param duplicateSimilarityThreshold perceptual similarity above which
     *   two images count as duplicates
     */
    fun checkImages(
        imageUris: List<Uri>,
        minImagesRequired: Int = 10,
        allowMultipleFaces: Boolean = false,
        duplicateSimilarityThreshold: Double = 0.95
    ) {
        viewModelScope.launch {
            try {
                _uiState.value = TrainingSanityUiState.Checking
                val result = sanityChecker.performSanityChecks(
                    imageUris = imageUris,
                    minImagesRequired = minImagesRequired,
                    allowMultipleFaces = allowMultipleFaces,
                    duplicateSimilarityThreshold = duplicateSimilarityThreshold
                )
                _uiState.value = TrainingSanityUiState.Success(result)
            } catch (e: kotlinx.coroutines.CancellationException) {
                // FIX: the broad catch below used to swallow cancellation and
                // surface a bogus Error state; rethrow so the scope cancels cleanly.
                throw e
            } catch (e: Exception) {
                _uiState.value = TrainingSanityUiState.Error(
                    e.message ?: "An unknown error occurred during sanity checks"
                )
            }
        }
    }

    /**
     * Reset the UI state back to Idle.
     */
    fun resetState() {
        _uiState.value = TrainingSanityUiState.Idle
    }

    /**
     * Get formatted error messages from a validation result.
     */
    fun getFormattedErrors(result: TrainingSanityChecker.SanityCheckResult): List<String> {
        return sanityChecker.formatValidationErrors(result.validationErrors)
    }

    override fun onCleared() {
        super.onCleared()
        sanityChecker.cleanup()
    }
}