Face detection + multiple-faces check.
Filtering before the crop prompt — do we need the user to crop photos that contain only one face?
This commit is contained in:
@@ -72,4 +72,10 @@ dependencies {
|
||||
|
||||
// Coil Images
|
||||
implementation(libs.coil.compose)
|
||||
|
||||
// ML Kit
|
||||
implementation(libs.mlkit.face.detection)
|
||||
implementation(libs.kotlinx.coroutines.play.services)
|
||||
|
||||
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package com.placeholder.sherpai2.ui.navigation
|
||||
|
||||
import android.net.Uri
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.navigation.NavHostController
|
||||
import androidx.navigation.NavType
|
||||
import androidx.navigation.compose.NavHost
|
||||
@@ -16,6 +18,19 @@ import java.net.URLDecoder
|
||||
import java.net.URLEncoder
|
||||
import com.placeholder.sherpai2.ui.tour.TourViewModel
|
||||
import com.placeholder.sherpai2.ui.tour.TourScreen
|
||||
import com.placeholder.sherpai2.ui.trainingprep.ImageSelectorScreen
|
||||
import com.placeholder.sherpai2.ui.trainingprep.TrainingScreen
|
||||
import com.placeholder.sherpai2.ui.navigation.AppRoutes
|
||||
import com.placeholder.sherpai2.ui.navigation.AppRoutes.ScanResultsScreen
|
||||
import com.placeholder.sherpai2.ui.trainingprep.ScanningState
|
||||
import com.placeholder.sherpai2.ui.trainingprep.TrainViewModel
|
||||
import androidx.compose.runtime.LaunchedEffect
|
||||
import androidx.compose.runtime.collectAsState
|
||||
import androidx.compose.runtime.getValue
|
||||
import com.placeholder.sherpai2.ui.trainingprep.ScanResultsScreen
|
||||
|
||||
|
||||
|
||||
@Composable
|
||||
fun AppNavHost(
|
||||
navController: NavHostController,
|
||||
@@ -71,13 +86,60 @@ fun AppNavHost(
|
||||
)
|
||||
}
|
||||
|
||||
/** TRAINING FLOW **/
|
||||
composable(AppRoutes.TRAIN) { entry ->
|
||||
val trainViewModel: TrainViewModel = hiltViewModel()
|
||||
val uiState by trainViewModel.uiState.collectAsState()
|
||||
|
||||
// Observe the result from the ImageSelector
|
||||
val selectedUris = entry.savedStateHandle.get<List<Uri>>("selected_image_uris")
|
||||
|
||||
// If we have new URIs and we are currently Idle, start scanning
|
||||
LaunchedEffect(selectedUris) {
|
||||
if (selectedUris != null && uiState is ScanningState.Idle) {
|
||||
trainViewModel.scanAndTagFaces(selectedUris)
|
||||
// Clear the handle so it doesn't re-trigger on configuration change
|
||||
entry.savedStateHandle.remove<List<Uri>>("selected_image_uris")
|
||||
}
|
||||
}
|
||||
|
||||
if (uiState is ScanningState.Idle) {
|
||||
// Initial state: Show start button or prompt
|
||||
TrainingScreen(
|
||||
onSelectImages = { navController.navigate(AppRoutes.IMAGE_SELECTOR) }
|
||||
)
|
||||
} else {
|
||||
// Processing or Success state: Show the results screen
|
||||
ScanResultsScreen(
|
||||
state = uiState,
|
||||
onFinish = {
|
||||
navController.navigate(AppRoutes.SEARCH) {
|
||||
popUpTo(AppRoutes.TRAIN) { inclusive = true }
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
composable(AppRoutes.IMAGE_SELECTOR) {
|
||||
ImageSelectorScreen(
|
||||
onImagesSelected = { uris ->
|
||||
navController.previousBackStackEntry
|
||||
?.savedStateHandle
|
||||
?.set("selected_image_uris", uris)
|
||||
|
||||
navController.popBackStack()
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/** DUMMY SCREENS FOR OTHER DRAWER ITEMS **/
|
||||
//composable(AppRoutes.TOUR) { DummyScreen("Tour (stub)") }
|
||||
composable(AppRoutes.MODELS) { DummyScreen("Models (stub)") }
|
||||
composable(AppRoutes.INVENTORY) { DummyScreen("Inventory (stub)") }
|
||||
composable(AppRoutes.TRAIN) { DummyScreen("Train (stub)") }
|
||||
//composable(AppRoutes.TRAIN) { DummyScreen("Train (stub)") }
|
||||
composable(AppRoutes.TAGS) { DummyScreen("Tags (stub)") }
|
||||
composable(AppRoutes.UPLOAD) { DummyScreen("Upload (stub)") }
|
||||
composable(AppRoutes.SETTINGS) { DummyScreen("Settings (stub)") }
|
||||
composable(AppRoutes.SETTINGS) { DummyScreen("Settings (stub)") }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,4 +20,13 @@ object AppRoutes {
|
||||
const val UPLOAD = "upload"
|
||||
const val SETTINGS = "settings"
|
||||
const val IMAGE_DETAIL = "IMAGE_DETAIL"
|
||||
|
||||
const val CROP_SCREEN = "CROP_SCREEN"
|
||||
const val IMAGE_SELECTOR = "Image Selection"
|
||||
const val TRAINING_SCREEN = "TRAINING_SCREEN"
|
||||
|
||||
const val ScanResultsScreen = "First Scan Results"
|
||||
|
||||
|
||||
//const val IMAGE_DETAIL = "IMAGE_DETAIL"
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@ import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.automirrored.filled.Label
|
||||
import androidx.compose.material.icons.automirrored.filled.List
|
||||
import androidx.compose.material.icons.filled.*
|
||||
import androidx.compose.material3.HorizontalDivider
|
||||
import com.placeholder.sherpai2.ui.navigation.AppRoutes
|
||||
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@@ -26,16 +29,20 @@ fun AppDrawerContent(
|
||||
modifier = Modifier.padding(16.dp)
|
||||
)
|
||||
|
||||
Divider(Modifier.fillMaxWidth(), thickness = DividerDefaults.Thickness)
|
||||
HorizontalDivider(
|
||||
Modifier.fillMaxWidth(),
|
||||
thickness = DividerDefaults.Thickness,
|
||||
color = DividerDefaults.color
|
||||
)
|
||||
|
||||
// Main drawer items
|
||||
val mainItems = listOf(
|
||||
Triple(AppRoutes.SEARCH, "Search", Icons.Default.Search),
|
||||
Triple(AppRoutes.TOUR, "Tour", Icons.Default.Place),
|
||||
Triple(AppRoutes.MODELS, "Models", Icons.Default.ModelTraining),
|
||||
Triple(AppRoutes.INVENTORY, "Inventory", Icons.Default.List),
|
||||
Triple(AppRoutes.INVENTORY, "Inventory", Icons.AutoMirrored.Filled.List),
|
||||
Triple(AppRoutes.TRAIN, "Train", Icons.Default.Train),
|
||||
Triple(AppRoutes.TAGS, "Tags", Icons.Default.Label)
|
||||
Triple(AppRoutes.TAGS, "Tags", Icons.AutoMirrored.Filled.Label)
|
||||
)
|
||||
|
||||
Column(modifier = Modifier.padding(vertical = 8.dp)) {
|
||||
|
||||
@@ -15,8 +15,8 @@ import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.placeholder.sherpai2.data.local.model.ImageWithEverything
|
||||
|
||||
@Composable
|
||||
fun TourScreen(viewModel: TourViewModel = hiltViewModel()) {
|
||||
val images by viewModel.recentImages.collectAsState()
|
||||
fun TourScreen(tourViewModel: TourViewModel = hiltViewModel(), onImageClick: (String) -> Unit) {
|
||||
val images by tourViewModel.recentImages.collectAsState()
|
||||
|
||||
Column(modifier = Modifier.fillMaxSize()) {
|
||||
// Header with image count
|
||||
@@ -42,11 +42,11 @@ fun TourScreen(viewModel: TourViewModel = hiltViewModel()) {
|
||||
fun ImageCard(image: ImageWithEverything) {
|
||||
Card(modifier = Modifier.fillMaxWidth(), elevation = CardDefaults.cardElevation(4.dp)) {
|
||||
Column(modifier = Modifier.padding(12.dp)) {
|
||||
Text(text = image.imageUri, style = MaterialTheme.typography.bodyMedium)
|
||||
Text(text = image.tags.toString(), style = MaterialTheme.typography.bodyMedium)
|
||||
|
||||
// Tags row with placeholders if fewer than 3
|
||||
Row(modifier = Modifier.padding(top = 8.dp)) {
|
||||
val tags = image.tags.map { it.name } // adjust depending on your entity
|
||||
val tags = image.tags.map { it.tagId } // adjust depending on your entity
|
||||
tags.forEach { tag ->
|
||||
TagComposable(tag)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
package com.placeholder.sherpai2.ui.trainingprep
|
||||
|
||||
import android.net.Uri
|
||||
import androidx.activity.compose.rememberLauncherForActivityResult
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.compose.foundation.layout.*
|
||||
|
||||
import androidx.compose.foundation.lazy.grid.GridCells
|
||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.AddPhotoAlternate
|
||||
import androidx.compose.material.icons.filled.Close
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.saveable.rememberSaveable
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import coil.compose.AsyncImage
|
||||
import androidx.compose.foundation.lazy.grid.items
|
||||
|
||||
|
||||
/**
 * Lets the user pick up to 10 training photos from the system document picker.
 *
 * The selection survives configuration changes (and process death) via
 * [rememberSaveable]. When the user confirms, [onImagesSelected] receives the
 * chosen URIs.
 *
 * @param onImagesSelected callback invoked with the final (non-empty) list of
 *   selected image URIs when the user taps "Start Face Detection".
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun ImageSelectorScreen(
    onImagesSelected: (List<Uri>) -> Unit
) {
    // 1. Persist state across configuration changes (and process death).
    var selectedUris by rememberSaveable { mutableStateOf<List<Uri>>(emptyList()) }
    val context = LocalContext.current

    val launcher = rememberLauncherForActivityResult(
        ActivityResultContracts.OpenMultipleDocuments()
    ) { uris ->
        // 2. Cap the selection at 10 and actually take a persistable read
        //    permission on each document URI. OpenMultipleDocuments grants are
        //    transient; without this the URIs restored by rememberSaveable can
        //    become unreadable later (the original code only said it would
        //    "try to persist permissions" but never did).
        val limitedUris = uris.take(10)
        limitedUris.forEach { uri ->
            try {
                context.contentResolver.takePersistableUriPermission(
                    uri,
                    android.content.Intent.FLAG_GRANT_READ_URI_PERMISSION
                )
            } catch (e: SecurityException) {
                // Best effort: some document providers refuse persistable grants;
                // the URI is still usable for the current session.
            }
        }
        selectedUris = limitedUris
    }

    Scaffold(
        topBar = { TopAppBar(title = { Text("Select Training Photos") }) }
    ) { padding ->
        Column(
            modifier = Modifier
                .padding(padding)
                .padding(16.dp)
                .fillMaxSize(),
            verticalArrangement = Arrangement.spacedBy(16.dp)
        ) {
            // Tappable card that launches the system image picker.
            OutlinedCard(
                onClick = { launcher.launch(arrayOf("image/*")) },
                modifier = Modifier.fillMaxWidth()
            ) {
                Column(
                    modifier = Modifier.padding(24.dp),
                    horizontalAlignment = Alignment.CenterHorizontally
                ) {
                    Icon(Icons.Default.AddPhotoAlternate, contentDescription = null)
                    Spacer(Modifier.height(8.dp))
                    Text("Select up to 10 images of the person")
                    Text(
                        text = "${selectedUris.size} / 10 selected",
                        style = MaterialTheme.typography.labelLarge,
                        // Error color at the cap, primary while partially filled,
                        // muted outline when nothing is selected yet.
                        color = if (selectedUris.size == 10) MaterialTheme.colorScheme.error
                        else if (selectedUris.isNotEmpty()) MaterialTheme.colorScheme.primary
                        else MaterialTheme.colorScheme.outline
                    )
                }
            }

            // 3. Conditional rendering for empty state
            if (selectedUris.isEmpty()) {
                Box(Modifier
                    .weight(1f)
                    .fillMaxWidth(), contentAlignment = Alignment.Center) {
                    Text("No images selected", style = MaterialTheme.typography.bodyMedium)
                }
            } else {
                LazyVerticalGrid(
                    columns = GridCells.Fixed(3),
                    modifier = Modifier.weight(1f),
                    contentPadding = PaddingValues(4.dp)
                ) {
                    items(selectedUris, key = { it.toString() }) { uri ->
                        Box(modifier = Modifier.padding(4.dp)) {
                            AsyncImage(
                                model = uri,
                                contentDescription = null,
                                modifier = Modifier
                                    .aspectRatio(1f)
                                    .clip(RoundedCornerShape(8.dp)),
                                contentScale = ContentScale.Crop
                            )
                            // 4. Ability to remove specific images
                            Surface(
                                onClick = { selectedUris = selectedUris - uri },
                                modifier = Modifier
                                    .align(Alignment.TopEnd)
                                    .padding(4.dp),
                                shape = CircleShape,
                                color = MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.8f)
                            ) {
                                Icon(
                                    Icons.Default.Close,
                                    contentDescription = "Remove",
                                    modifier = Modifier.size(16.dp)
                                )
                            }
                        }
                    }
                }
            }

            Button(
                modifier = Modifier.fillMaxWidth(),
                enabled = selectedUris.isNotEmpty(),
                onClick = { onImagesSelected(selectedUris) }
            ) {
                Text("Start Face Detection")
            }
        }
    }
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package com.placeholder.sherpai2.ui.trainingprep
|
||||
|
||||
import android.net.Uri
|
||||
import androidx.compose.foundation.Image
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.unit.dp
|
||||
import coil.compose.rememberAsyncImagePainter
|
||||
|
||||
/**
 * Shows the progress and outcome of the face-scanning pass over the selected
 * images.
 *
 * @param state current scanning state; Processing shows a spinner with a
 *   counter, Success lists each image with its detection summary, Idle renders
 *   nothing (the caller only navigates here once scanning has started).
 * @param onFinish invoked from the "Done" button once results are shown.
 */
@Composable
fun ScanResultsScreen(
    state: ScanningState,
    onFinish: () -> Unit
) {
    Column(
        modifier = Modifier.fillMaxSize().padding(16.dp),
        horizontalAlignment = Alignment.CenterHorizontally,
        verticalArrangement = Arrangement.Center
    ) {
        // Exhaustive `when` over the sealed ScanningState: any state added
        // later becomes a compile error here instead of silently falling into
        // the `else -> {}` branch the original code used.
        when (state) {
            is ScanningState.Processing -> {
                CircularProgressIndicator()
                Spacer(Modifier.height(16.dp))
                Text("Analyzing faces... ${state.current} / ${state.total}")
            }
            is ScanningState.Success -> {
                Text(
                    text = "Analysis Complete!",
                    style = MaterialTheme.typography.headlineMedium
                )

                // One row per scanned image: thumbnail + detection summary.
                LazyColumn(modifier = Modifier.weight(1f).padding(vertical = 16.dp)) {
                    items(state.results) { result ->
                        Row(
                            modifier = Modifier.fillMaxWidth().padding(8.dp),
                            verticalAlignment = Alignment.CenterVertically
                        ) {
                            Image(
                                painter = rememberAsyncImagePainter(result.uri),
                                contentDescription = null,
                                modifier = Modifier.size(60.dp).clip(RoundedCornerShape(8.dp)),
                                contentScale = ContentScale.Crop
                            )
                            Spacer(Modifier.width(16.dp))
                            Column {
                                Text(if (result.faceCount > 0) "✅ Face Detected" else "❌ No Face")
                                if (result.hasMultipleFaces) {
                                    Text(
                                        "⚠️ Multiple faces (${result.faceCount})",
                                        color = MaterialTheme.colorScheme.error,
                                        style = MaterialTheme.typography.bodySmall
                                    )
                                }
                            }
                        }
                    }
                }

                Button(onClick = onFinish, modifier = Modifier.fillMaxWidth()) {
                    Text("Done")
                }
            }
            // Deliberately empty: nothing to show before scanning starts.
            ScanningState.Idle -> Unit
        }
    }
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package com.placeholder.sherpai2.ui.trainingprep
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
import com.placeholder.sherpai2.domain.repository.ImageRepository
|
||||
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.sync.withPermit
|
||||
import kotlinx.coroutines.tasks.await
|
||||
import kotlinx.coroutines.withContext
|
||||
import javax.inject.Inject
|
||||
|
||||
/** UI state machine for the face-scanning flow. */
sealed class ScanningState {
    /** No scan has been started yet. */
    object Idle : ScanningState()

    /** Scan in progress: [current] of [total] images processed so far. */
    data class Processing(val current: Int, val total: Int) : ScanningState()

    /** Scan finished: one [ScanResult] per scanned image, in input order. */
    data class Success(val results: List<ScanResult>) : ScanningState()
}

/**
 * Outcome of scanning a single image.
 *
 * @property uri the image that was scanned.
 * @property faceCount number of faces ML Kit reported (0 on read/detect failure).
 * @property hasMultipleFaces defaults to `faceCount > 1`; note this is a plain
 *   constructor parameter, so `copy(faceCount = ...)` does NOT recompute it —
 *   callers should rely on the default rather than passing it explicitly.
 */
data class ScanResult(
    val uri: Uri,
    val faceCount: Int,
    val hasMultipleFaces: Boolean = faceCount > 1
)
|
||||
|
||||
/**
 * Runs ML Kit face detection over a batch of user-selected images, tags the
 * matching repository images ("face" / "multiple_faces"), and publishes
 * progress through [uiState].
 */
@HiltViewModel
class TrainViewModel @Inject constructor(
    @ApplicationContext private val context: Context,
    private val imageRepository: ImageRepository,
    private val taggingRepository: TaggingRepository
) : ViewModel() {

    private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
    /** Observable scanning progress/result for the UI. */
    val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()

    // Cap concurrent detections at 2 to limit memory pressure from decoding
    // several full-size images at once.
    private val semaphore = Semaphore(2)

    /**
     * Scans each URI for faces and tags images already known to the
     * repository. Progress is emitted as [ScanningState.Processing] and the
     * final per-image results as [ScanningState.Success].
     */
    fun scanAndTagFaces(uris: List<Uri>) = viewModelScope.launch {
        val total = uris.size
        _uiState.value = ScanningState.Processing(0, total)

        val detector = FaceDetection.getClient(faceOptions())
        try {
            val allImages = imageRepository.getAllImages().first()
            val uriToIdMap = allImages.associate { it.image.imageUri to it.image.imageId }

            // Atomic counter: increments happen concurrently from up to two
            // coroutines, so the original plain `var completedCount++` was a
            // data race that could lose updates and report wrong progress.
            val completedCount = java.util.concurrent.atomic.AtomicInteger(0)

            val scanResults = withContext(Dispatchers.Default) {
                uris.map { uri ->
                    async {
                        semaphore.withPermit {
                            val faceCount = detectFaceCount(detector, uri)

                            // Tagging: only images already present in the repository
                            // (looked up by string URI) get tagged.
                            if (faceCount > 0) {
                                uriToIdMap[uri.toString()]?.let { id ->
                                    taggingRepository.addTagToImage(id, "face", "ML_KIT", 1.0f)
                                    if (faceCount > 1) {
                                        taggingRepository.addTagToImage(id, "multiple_faces", "ML_KIT", 1.0f)
                                    }
                                }
                            }

                            _uiState.value =
                                ScanningState.Processing(completedCount.incrementAndGet(), total)

                            ScanResult(uri, faceCount)
                        }
                    }
                }.awaitAll()
            }

            _uiState.value = ScanningState.Success(scanResults)
        } finally {
            // Release the native detector even if scanning throws or the
            // coroutine is cancelled (the original leaked it on failure).
            detector.close()
        }
    }

    /**
     * Counts faces in the image at [uri]. Returns 0 if the image cannot be
     * read or detection fails; coroutine cancellation is re-thrown rather
     * than being swallowed by the broad catch.
     */
    private suspend fun detectFaceCount(
        detector: com.google.mlkit.vision.face.FaceDetector,
        uri: Uri
    ): Int = withContext(Dispatchers.IO) {
        try {
            val image = InputImage.fromFilePath(context, uri)
            detector.process(image).await().size
        } catch (ce: kotlin.coroutines.cancellation.CancellationException) {
            throw ce // never swallow cancellation inside a coroutine
        } catch (e: Exception) {
            0
        }
    }

    // FAST mode trades landmark precision for speed; we only need a count.
    private fun faceOptions() = FaceDetectorOptions.Builder()
        .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_FAST)
        .build()
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.placeholder.sherpai2.ui.trainingprep
|
||||
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TopAppBar
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.hilt.lifecycle.viewmodel.compose.hiltViewModel
|
||||
/**
 * Entry screen for the training flow: a single call-to-action that hands
 * control to the image selector.
 *
 * @param onSelectImages invoked when the user taps "Select Images".
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun TrainingScreen(
    onSelectImages: () -> Unit
) {
    Scaffold(
        topBar = { TopAppBar(title = { Text("Training") }) }
    ) { innerPadding ->
        Button(
            onClick = onSelectImages,
            modifier = Modifier.padding(innerPadding)
        ) {
            Text("Select Images")
        }
    }
}
|
||||
Reference in New Issue
Block a user