face detection + multi faces check

filtering before the crop prompt — do we need to have the user crop photos that contain only one face?
This commit is contained in:
genki
2026-01-01 01:02:42 -05:00
parent 3f15bfabc1
commit dba64b89b6
10 changed files with 444 additions and 9 deletions

View File

@@ -72,4 +72,10 @@ dependencies {
// Coil Images
implementation(libs.coil.compose)
// ML Kit
implementation(libs.mlkit.face.detection)
implementation(libs.kotlinx.coroutines.play.services)
}

View File

@@ -1,8 +1,10 @@
package com.placeholder.sherpai2.ui.navigation
import android.net.Uri
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.ViewModel
import androidx.navigation.NavHostController
import androidx.navigation.NavType
import androidx.navigation.compose.NavHost
@@ -16,6 +18,19 @@ import java.net.URLDecoder
import java.net.URLEncoder
import com.placeholder.sherpai2.ui.tour.TourViewModel
import com.placeholder.sherpai2.ui.tour.TourScreen
import com.placeholder.sherpai2.ui.trainingprep.ImageSelectorScreen
import com.placeholder.sherpai2.ui.trainingprep.TrainingScreen
import com.placeholder.sherpai2.ui.navigation.AppRoutes
import com.placeholder.sherpai2.ui.navigation.AppRoutes.ScanResultsScreen
import com.placeholder.sherpai2.ui.trainingprep.ScanningState
import com.placeholder.sherpai2.ui.trainingprep.TrainViewModel
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue
import com.placeholder.sherpai2.ui.trainingprep.ScanResultsScreen
@Composable
fun AppNavHost(
navController: NavHostController,
@@ -71,11 +86,58 @@ fun AppNavHost(
)
}
/** TRAINING FLOW **/
composable(AppRoutes.TRAIN) { entry ->
val trainViewModel: TrainViewModel = hiltViewModel()
val uiState by trainViewModel.uiState.collectAsState()
// Observe the result from the ImageSelector
val selectedUris = entry.savedStateHandle.get<List<Uri>>("selected_image_uris")
// If we have new URIs and we are currently Idle, start scanning
LaunchedEffect(selectedUris) {
if (selectedUris != null && uiState is ScanningState.Idle) {
trainViewModel.scanAndTagFaces(selectedUris)
// Clear the handle so it doesn't re-trigger on configuration change
entry.savedStateHandle.remove<List<Uri>>("selected_image_uris")
}
}
if (uiState is ScanningState.Idle) {
// Initial state: Show start button or prompt
TrainingScreen(
onSelectImages = { navController.navigate(AppRoutes.IMAGE_SELECTOR) }
)
} else {
// Processing or Success state: Show the results screen
ScanResultsScreen(
state = uiState,
onFinish = {
navController.navigate(AppRoutes.SEARCH) {
popUpTo(AppRoutes.TRAIN) { inclusive = true }
}
}
)
}
}
composable(AppRoutes.IMAGE_SELECTOR) {
ImageSelectorScreen(
onImagesSelected = { uris ->
navController.previousBackStackEntry
?.savedStateHandle
?.set("selected_image_uris", uris)
navController.popBackStack()
}
)
}
/** DUMMY SCREENS FOR OTHER DRAWER ITEMS **/
//composable(AppRoutes.TOUR) { DummyScreen("Tour (stub)") }
composable(AppRoutes.MODELS) { DummyScreen("Models (stub)") }
composable(AppRoutes.INVENTORY) { DummyScreen("Inventory (stub)") }
composable(AppRoutes.TRAIN) { DummyScreen("Train (stub)") }
//composable(AppRoutes.TRAIN) { DummyScreen("Train (stub)") }
composable(AppRoutes.TAGS) { DummyScreen("Tags (stub)") }
composable(AppRoutes.UPLOAD) { DummyScreen("Upload (stub)") }
composable(AppRoutes.SETTINGS) { DummyScreen("Settings (stub)") }

View File

@@ -20,4 +20,13 @@ object AppRoutes {
const val UPLOAD = "upload"
const val SETTINGS = "settings"
const val IMAGE_DETAIL = "IMAGE_DETAIL"
const val CROP_SCREEN = "CROP_SCREEN"
const val IMAGE_SELECTOR = "Image Selection"
const val TRAINING_SCREEN = "TRAINING_SCREEN"
const val ScanResultsScreen = "First Scan Results"
//const val IMAGE_DETAIL = "IMAGE_DETAIL"
}

View File

@@ -7,7 +7,10 @@ import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.automirrored.filled.Label
import androidx.compose.material.icons.automirrored.filled.List
import androidx.compose.material.icons.filled.*
import androidx.compose.material3.HorizontalDivider
import com.placeholder.sherpai2.ui.navigation.AppRoutes
@OptIn(ExperimentalMaterial3Api::class)
@@ -26,16 +29,20 @@ fun AppDrawerContent(
modifier = Modifier.padding(16.dp)
)
Divider(Modifier.fillMaxWidth(), thickness = DividerDefaults.Thickness)
HorizontalDivider(
Modifier.fillMaxWidth(),
thickness = DividerDefaults.Thickness,
color = DividerDefaults.color
)
// Main drawer items
val mainItems = listOf(
Triple(AppRoutes.SEARCH, "Search", Icons.Default.Search),
Triple(AppRoutes.TOUR, "Tour", Icons.Default.Place),
Triple(AppRoutes.MODELS, "Models", Icons.Default.ModelTraining),
Triple(AppRoutes.INVENTORY, "Inventory", Icons.Default.List),
Triple(AppRoutes.INVENTORY, "Inventory", Icons.AutoMirrored.Filled.List),
Triple(AppRoutes.TRAIN, "Train", Icons.Default.Train),
Triple(AppRoutes.TAGS, "Tags", Icons.Default.Label)
Triple(AppRoutes.TAGS, "Tags", Icons.AutoMirrored.Filled.Label)
)
Column(modifier = Modifier.padding(vertical = 8.dp)) {

View File

@@ -15,8 +15,8 @@ import androidx.hilt.navigation.compose.hiltViewModel
import com.placeholder.sherpai2.data.local.model.ImageWithEverything
@Composable
fun TourScreen(viewModel: TourViewModel = hiltViewModel()) {
val images by viewModel.recentImages.collectAsState()
fun TourScreen(tourViewModel: TourViewModel = hiltViewModel(), onImageClick: (String) -> Unit) {
val images by tourViewModel.recentImages.collectAsState()
Column(modifier = Modifier.fillMaxSize()) {
// Header with image count
@@ -42,11 +42,11 @@ fun TourScreen(viewModel: TourViewModel = hiltViewModel()) {
fun ImageCard(image: ImageWithEverything) {
Card(modifier = Modifier.fillMaxWidth(), elevation = CardDefaults.cardElevation(4.dp)) {
Column(modifier = Modifier.padding(12.dp)) {
Text(text = image.imageUri, style = MaterialTheme.typography.bodyMedium)
Text(text = image.tags.toString(), style = MaterialTheme.typography.bodyMedium)
// Tags row with placeholders if fewer than 3
Row(modifier = Modifier.padding(top = 8.dp)) {
val tags = image.tags.map { it.name } // adjust depending on your entity
val tags = image.tags.map { it.tagId } // adjust depending on your entity
tags.forEach { tag ->
TagComposable(tag)
}

View File

@@ -0,0 +1,130 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.net.Uri
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.AddPhotoAlternate
import androidx.compose.material.icons.filled.Close
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.unit.dp
import androidx.compose.material3.Text
import androidx.compose.runtime.saveable.rememberSaveable
import androidx.compose.ui.draw.clip
import androidx.compose.ui.platform.LocalContext
import coil.compose.AsyncImage
import androidx.compose.foundation.lazy.grid.items
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ImageSelectorScreen(
    onImagesSelected: (List<Uri>) -> Unit
) {
    // 1. Persist selection across configuration changes (Uri is Parcelable, so
    //    rememberSaveable can bundle the list).
    var selectedUris by rememberSaveable { mutableStateOf<List<Uri>>(emptyList()) }
    val context = LocalContext.current
    val launcher = rememberLauncherForActivityResult(
        ActivityResultContracts.OpenMultipleDocuments()
    ) { uris ->
        // 2. Take first 10 and persist a read grant for each URI so they remain
        //    readable after process death (OpenMultipleDocuments grants are
        //    otherwise transient). Fixes the previously unimplemented
        //    "try to persist permissions" intent.
        val limitedUris = uris.take(10)
        limitedUris.forEach { uri ->
            try {
                context.contentResolver.takePersistableUriPermission(
                    uri,
                    android.content.Intent.FLAG_GRANT_READ_URI_PERMISSION
                )
            } catch (e: SecurityException) {
                // Provider did not offer a persistable grant; the transient
                // grant still covers this session, so continue best-effort.
            }
        }
        selectedUris = limitedUris
    }
    Scaffold(
        topBar = { TopAppBar(title = { Text("Select Training Photos") }) }
    ) { padding ->
        Column(
            modifier = Modifier
                .padding(padding)
                .padding(16.dp)
                .fillMaxSize(),
            verticalArrangement = Arrangement.spacedBy(16.dp)
        ) {
            // Tappable card that opens the system document picker.
            OutlinedCard(
                onClick = { launcher.launch(arrayOf("image/*")) },
                modifier = Modifier.fillMaxWidth()
            ) {
                Column(
                    modifier = Modifier.padding(24.dp),
                    horizontalAlignment = Alignment.CenterHorizontally
                ) {
                    Icon(Icons.Default.AddPhotoAlternate, contentDescription = null)
                    Spacer(Modifier.height(8.dp))
                    Text("Select up to 10 images of the person")
                    // Counter color signals state: error at the cap, primary when
                    // some are selected, outline when empty.
                    Text(
                        text = "${selectedUris.size} / 10 selected",
                        style = MaterialTheme.typography.labelLarge,
                        color = if (selectedUris.size == 10) MaterialTheme.colorScheme.error
                        else if (selectedUris.isNotEmpty()) MaterialTheme.colorScheme.primary
                        else MaterialTheme.colorScheme.outline
                    )
                }
            }
            // 3. Conditional rendering for empty state
            if (selectedUris.isEmpty()) {
                Box(
                    Modifier
                        .weight(1f)
                        .fillMaxWidth(),
                    contentAlignment = Alignment.Center
                ) {
                    Text("No images selected", style = MaterialTheme.typography.bodyMedium)
                }
            } else {
                LazyVerticalGrid(
                    columns = GridCells.Fixed(3),
                    modifier = Modifier.weight(1f),
                    contentPadding = PaddingValues(4.dp)
                ) {
                    items(selectedUris, key = { it.toString() }) { uri ->
                        Box(modifier = Modifier.padding(4.dp)) {
                            AsyncImage(
                                model = uri,
                                contentDescription = null,
                                modifier = Modifier
                                    .aspectRatio(1f)
                                    .clip(RoundedCornerShape(8.dp)),
                                contentScale = ContentScale.Crop
                            )
                            // 4. Ability to remove specific images
                            Surface(
                                onClick = { selectedUris = selectedUris - uri },
                                modifier = Modifier
                                    .align(Alignment.TopEnd)
                                    .padding(4.dp),
                                shape = CircleShape,
                                color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.8f)
                            ) {
                                Icon(
                                    Icons.Default.Close,
                                    contentDescription = "Remove",
                                    modifier = Modifier.size(16.dp)
                                )
                            }
                        }
                    }
                }
            }
            Button(
                modifier = Modifier.fillMaxWidth(),
                enabled = selectedUris.isNotEmpty(),
                onClick = { onImagesSelected(selectedUris) }
            ) {
                Text("Start Face Detection")
            }
        }
    }
}

View File

@@ -0,0 +1,74 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.net.Uri
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material3.*
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.unit.dp
import coil.compose.rememberAsyncImagePainter
/**
 * Shows progress while photos are analyzed and, once finished, a per-photo
 * summary of how many faces were detected.
 */
@Composable
fun ScanResultsScreen(
    state: ScanningState,
    onFinish: () -> Unit
) {
    Column(
        modifier = Modifier.fillMaxSize().padding(16.dp),
        horizontalAlignment = Alignment.CenterHorizontally,
        verticalArrangement = Arrangement.Center
    ) {
        when (state) {
            is ScanningState.Processing -> {
                CircularProgressIndicator()
                Spacer(Modifier.height(16.dp))
                Text("Analyzing faces... ${state.current} / ${state.total}")
            }
            is ScanningState.Success -> {
                Text(
                    text = "Analysis Complete!",
                    style = MaterialTheme.typography.headlineMedium
                )
                LazyColumn(modifier = Modifier.weight(1f).padding(vertical = 16.dp)) {
                    items(state.results) { ScanResultRow(it) }
                }
                Button(onClick = onFinish, modifier = Modifier.fillMaxWidth()) {
                    Text("Done")
                }
            }
            // Idle (and any future states): nothing to render here.
            else -> {}
        }
    }
}

/** A single results-list row: image thumbnail plus face-count status text. */
@Composable
private fun ScanResultRow(result: ScanResult) {
    Row(
        modifier = Modifier.fillMaxWidth().padding(8.dp),
        verticalAlignment = Alignment.CenterVertically
    ) {
        Image(
            painter = rememberAsyncImagePainter(result.uri),
            contentDescription = null,
            modifier = Modifier.size(60.dp).clip(RoundedCornerShape(8.dp)),
            contentScale = ContentScale.Crop
        )
        Spacer(Modifier.width(16.dp))
        Column {
            Text(if (result.faceCount > 0) "✅ Face Detected" else "❌ No Face")
            // Warn when more than one face was found in the photo.
            if (result.hasMultipleFaces) {
                Text(
                    "⚠️ Multiple faces (${result.faceCount})",
                    color = MaterialTheme.colorScheme.error,
                    style = MaterialTheme.typography.bodySmall
                )
            }
        }
    }
}

View File

@@ -0,0 +1,107 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.content.Context
import android.net.Uri
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.domain.repository.ImageRepository
import com.placeholder.sherpai2.domain.repository.TaggingRepository
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.tasks.await
import kotlinx.coroutines.withContext
import javax.inject.Inject
/** UI state for the training-photo face-scan flow. */
sealed class ScanningState {
    /** No scan has been started yet. */
    object Idle : ScanningState()
    /** Scan in progress: [current] of [total] images processed so far. */
    data class Processing(val current: Int, val total: Int) : ScanningState()
    /** Scan finished; one [ScanResult] per scanned image. */
    data class Success(val results: List<ScanResult>) : ScanningState()
}
/**
 * Face-detection outcome for a single image.
 *
 * @property uri the image that was scanned
 * @property faceCount number of faces the detector reported (0 = none found)
 * @property hasMultipleFaces convenience flag derived from [faceCount]
 */
data class ScanResult(
    val uri: Uri,
    val faceCount: Int,
    val hasMultipleFaces: Boolean = faceCount > 1
)
@HiltViewModel
class TrainViewModel @Inject constructor(
    @ApplicationContext private val context: Context,
    private val imageRepository: ImageRepository,
    private val taggingRepository: TaggingRepository
) : ViewModel() {

    private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
    /** Observable scan progress and results for the training-flow UI. */
    val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()

    // Limit concurrent detections; at most 2 images are decoded at once.
    private val semaphore = Semaphore(2)

    /**
     * Runs face detection over [uris], tags matching repository images
     * ("face", plus "multiple_faces" when more than one face is found), and
     * publishes [ScanningState.Processing] progress followed by
     * [ScanningState.Success] through [uiState].
     */
    fun scanAndTagFaces(uris: List<Uri>) = viewModelScope.launch {
        val total = uris.size
        _uiState.value = ScanningState.Processing(0, total)
        val detector = FaceDetection.getClient(faceOptions())
        try {
            val allImages = imageRepository.getAllImages().first()
            val uriToIdMap = allImages.associate { it.image.imageUri to it.image.imageId }
            // Incremented from concurrent workers on Dispatchers.Default —
            // must be atomic (a plain Int++ here would race).
            val completedCount = java.util.concurrent.atomic.AtomicInteger(0)
            val scanResults = withContext(Dispatchers.Default) {
                uris.map { uri ->
                    async {
                        semaphore.withPermit {
                            val faceCount = detectFaceCount(detector, uri)
                            // Tagging logic: only images already known to the
                            // repository (matched by URI string) get tagged.
                            if (faceCount > 0) {
                                uriToIdMap[uri.toString()]?.let { id ->
                                    taggingRepository.addTagToImage(id, "face", "ML_KIT", 1.0f)
                                    if (faceCount > 1) {
                                        taggingRepository.addTagToImage(id, "multiple_faces", "ML_KIT", 1.0f)
                                    }
                                }
                            }
                            _uiState.value =
                                ScanningState.Processing(completedCount.incrementAndGet(), total)
                            ScanResult(uri, faceCount)
                        }
                    }
                }.awaitAll()
            }
            _uiState.value = ScanningState.Success(scanResults)
        } finally {
            // Release native detector resources even on failure/cancellation
            // (previously leaked if anything above threw).
            detector.close()
        }
    }

    /**
     * Counts faces in the image at [uri]. Returns 0 if the image cannot be
     * read or detection fails; coroutine cancellation is propagated rather
     * than swallowed.
     */
    private suspend fun detectFaceCount(
        detector: com.google.mlkit.vision.face.FaceDetector,
        uri: Uri
    ): Int = withContext(Dispatchers.IO) {
        return@withContext try {
            val image = InputImage.fromFilePath(context, uri)
            val faces = detector.process(image).await()
            faces.size // Returns actual count
        } catch (e: kotlinx.coroutines.CancellationException) {
            throw e // never treat cancellation as "0 faces"
        } catch (e: Exception) {
            0
        }
    }

    /** FAST mode favors speed over landmark precision — adequate for counting. */
    private fun faceOptions() = FaceDetectorOptions.Builder()
        .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_FAST)
        .build()
}

View File

@@ -0,0 +1,31 @@
package com.placeholder.sherpai2.ui.trainingprep
import androidx.compose.foundation.layout.padding
import androidx.compose.material3.Button
import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBar
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.hilt.lifecycle.viewmodel.compose.hiltViewModel
/**
 * Entry screen of the training flow: offers a single action that sends the
 * user off to pick the photos used for face training.
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun TrainingScreen(
    onSelectImages: () -> Unit
) {
    Scaffold(
        topBar = { TopAppBar(title = { Text("Training") }) }
    ) { innerPadding ->
        Button(
            onClick = onSelectImages,
            modifier = Modifier.padding(innerPadding)
        ) {
            Text("Select Images")
        }
    }
}

View File

@@ -19,6 +19,10 @@ room = "2.8.4"
# Images
coil = "2.7.0"
#Face Detect
mlkit-face-detection = "16.1.6"
coroutines-play-services = "1.8.1"
[libraries]
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
androidx-lifecycle-runtime-ktx = { group = "androidx.lifecycle", name = "lifecycle-runtime-ktx", version.ref = "lifecycle" }
@@ -48,6 +52,11 @@ room-compiler = { group = "androidx.room", name = "room-compiler", version.ref =
# Misc
coil-compose = { group = "io.coil-kt", name = "coil-compose", version.ref = "coil" }
#Face Detect
mlkit-face-detection = { group = "com.google.mlkit", name = "face-detection", version.ref = "mlkit-face-detection" }
kotlinx-coroutines-play-services = { group = "org.jetbrains.kotlinx", name = "kotlinx-coroutines-play-services", version.ref = "coroutines-play-services" }
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }