Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 125 additions & 97 deletions core/src/main/java/com/google/adk/artifacts/GcsArtifactService.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
import com.google.common.base.Splitter;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import com.google.genai.types.FileData;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -135,22 +135,25 @@ public Maybe<Part> loadArtifact(
.flatMapMaybe(
versions -> versions.isEmpty() ? Maybe.empty() : Maybe.just(max(versions))))
.flatMap(
versionToLoad -> {
String blobName = getBlobName(appName, userId, sessionId, filename, versionToLoad);
BlobId blobId = BlobId.of(bucketName, blobName);
versionToLoad ->
Maybe.fromCallable(
() -> {
String blobName =
getBlobName(appName, userId, sessionId, filename, versionToLoad);
BlobId blobId = BlobId.of(bucketName, blobName);

try {
Blob blob = storageClient.get(blobId);
if (blob == null || !blob.exists()) {
return Maybe.empty();
}
byte[] data = blob.getContent();
String mimeType = blob.getContentType();
return Maybe.just(Part.fromBytes(data, mimeType));
} catch (StorageException e) {
return Maybe.empty();
}
});
try {
Blob blob = storageClient.get(blobId);
if (blob == null || !blob.exists()) {
return null;
}
byte[] data = blob.getContent();
String mimeType = blob.getContentType();
return Part.fromBytes(data, mimeType);
} catch (StorageException e) {
return null;
}
}));
}

/**
Expand All @@ -164,34 +167,38 @@ public Maybe<Part> loadArtifact(
@Override
public Single<ListArtifactsResponse> listArtifactKeys(
String appName, String userId, String sessionId) {
Set<String> filenames = new HashSet<>();
return Single.fromCallable(
() -> {
Set<String> filenames = new HashSet<>();

// List session-specific files
String sessionPrefix = String.format("%s/%s/%s/", appName, userId, sessionId);
try {
for (Blob blob :
storageClient.list(bucketName, BlobListOption.prefix(sessionPrefix)).iterateAll()) {
List<String> parts = Splitter.on('/').splitToList(blob.getName());
filenames.add(parts.get(3)); // appName/userId/sessionId/filename/version
}
} catch (StorageException e) {
throw new VerifyException("Failed to list session artifacts from GCS", e);
}
// List session-specific files
String sessionPrefix = String.format("%s/%s/%s/", appName, userId, sessionId);
try {
for (Blob blob :
storageClient.list(bucketName, BlobListOption.prefix(sessionPrefix)).iterateAll()) {
List<String> parts = Splitter.on('/').splitToList(blob.getName());
filenames.add(parts.get(3)); // appName/userId/sessionId/filename/version
}
} catch (StorageException e) {
throw new VerifyException("Failed to list session artifacts from GCS", e);
}

// List user-namespace files
String userPrefix = String.format("%s/%s/user/", appName, userId);
try {
for (Blob blob :
storageClient.list(bucketName, BlobListOption.prefix(userPrefix)).iterateAll()) {
List<String> parts = Splitter.on('/').splitToList(blob.getName());
filenames.add(parts.get(3)); // appName/userId/user/filename/version
}
} catch (StorageException e) {
throw new VerifyException("Failed to list user artifacts from GCS", e);
}
// List user-namespace files
String userPrefix = String.format("%s/%s/user/", appName, userId);
try {
for (Blob blob :
storageClient.list(bucketName, BlobListOption.prefix(userPrefix)).iterateAll()) {
List<String> parts = Splitter.on('/').splitToList(blob.getName());
filenames.add(parts.get(3)); // appName/userId/user/filename/version
}
} catch (StorageException e) {
throw new VerifyException("Failed to list user artifacts from GCS", e);
}

return Single.just(
ListArtifactsResponse.builder().filenames(ImmutableList.sortedCopyOf(filenames)).build());
return ListArtifactsResponse.builder()
.filenames(ImmutableList.sortedCopyOf(filenames))
.build();
});
}

/**
Expand All @@ -206,22 +213,30 @@ public Single<ListArtifactsResponse> listArtifactKeys(
@Override
public Completable deleteArtifact(
String appName, String userId, String sessionId, String filename) {
ImmutableList<Integer> versions =
listVersions(appName, userId, sessionId, filename).blockingGet();
List<BlobId> blobIdsToDelete = new ArrayList<>();
for (int version : versions) {
String blobName = getBlobName(appName, userId, sessionId, filename, version);
blobIdsToDelete.add(BlobId.of(bucketName, blobName));
}
return listVersions(appName, userId, sessionId, filename)
.flatMapCompletable(
versions -> {
if (versions.isEmpty()) {
return Completable.complete();
}
ImmutableList<BlobId> blobIdsToDelete =
versions.stream()
.map(
version ->
BlobId.of(
bucketName,
getBlobName(appName, userId, sessionId, filename, version)))
.collect(ImmutableList.toImmutableList());

if (!blobIdsToDelete.isEmpty()) {
try {
var unused = storageClient.delete(blobIdsToDelete);
} catch (StorageException e) {
throw new VerifyException("Failed to delete artifact versions from GCS", e);
}
}
return Completable.complete();
return Completable.fromAction(
() -> {
try {
var unused = storageClient.delete(blobIdsToDelete);
} catch (StorageException e) {
throw new VerifyException("Failed to delete artifact versions from GCS", e);
}
});
});
}

/**
Expand All @@ -236,20 +251,29 @@ public Completable deleteArtifact(
@Override
public Single<ImmutableList<Integer>> listVersions(
String appName, String userId, String sessionId, String filename) {
String prefix = getBlobPrefix(appName, userId, sessionId, filename);
List<Integer> versions = new ArrayList<>();
try {
for (Blob blob : storageClient.list(bucketName, BlobListOption.prefix(prefix)).iterateAll()) {
String name = blob.getName();
int versionDelimiterIndex = name.lastIndexOf('/'); // immediately before the version number
if (versionDelimiterIndex != -1 && versionDelimiterIndex < name.length() - 1) {
versions.add(Integer.parseInt(name.substring(versionDelimiterIndex + 1)));
}
}
return Single.just(ImmutableList.sortedCopyOf(versions));
} catch (StorageException e) {
return Single.just(ImmutableList.of());
}
return Single.fromCallable(
() -> {
String prefix = getBlobPrefix(appName, userId, sessionId, filename);
try {
return Streams.stream(
storageClient.list(bucketName, BlobListOption.prefix(prefix)).iterateAll())
.map(Blob::getName)
.map(
name -> {
int versionDelimiterIndex = name.lastIndexOf('/');
return versionDelimiterIndex != -1
&& versionDelimiterIndex < name.length() - 1
? Optional.of(name.substring(versionDelimiterIndex + 1))
: Optional.<String>empty();
})
.flatMap(Optional::stream)
.map(Integer::parseInt)
.sorted()
.collect(ImmutableList.toImmutableList());
} catch (StorageException e) {
return ImmutableList.of();
}
});
}

@Override
Expand Down Expand Up @@ -291,35 +315,39 @@ private Single<SaveResult> saveArtifactAndReturnBlob(
String appName, String userId, String sessionId, String filename, Part artifact) {
return listVersions(appName, userId, sessionId, filename)
.map(versions -> versions.isEmpty() ? 0 : max(versions) + 1)
.map(
nextVersion -> {
if (artifact.inlineData().isEmpty()) {
throw new IllegalArgumentException("Saveable artifact must have inline data.");
}
.flatMap(
nextVersion ->
Single.fromCallable(
() -> {
if (artifact.inlineData().isEmpty()) {
throw new IllegalArgumentException(
"Saveable artifact must have inline data.");
}

String blobName = getBlobName(appName, userId, sessionId, filename, nextVersion);
BlobId blobId = BlobId.of(bucketName, blobName);
String blobName =
getBlobName(appName, userId, sessionId, filename, nextVersion);
BlobId blobId = BlobId.of(bucketName, blobName);

BlobInfo blobInfo =
BlobInfo.newBuilder(blobId)
.setContentType(artifact.inlineData().get().mimeType().orElse(null))
.build();
BlobInfo blobInfo =
BlobInfo.newBuilder(blobId)
.setContentType(artifact.inlineData().get().mimeType().orElse(null))
.build();

try {
byte[] dataToSave =
artifact
.inlineData()
.get()
.data()
.orElseThrow(
() ->
new IllegalArgumentException(
"Saveable artifact data must be non-empty."));
Blob blob = storageClient.create(blobInfo, dataToSave);
return SaveResult.create(blob, nextVersion);
} catch (StorageException e) {
throw new VerifyException("Failed to save artifact to GCS", e);
}
});
try {
byte[] dataToSave =
artifact
.inlineData()
.get()
.data()
.orElseThrow(
() ->
new IllegalArgumentException(
"Saveable artifact data must be non-empty."));
Blob blob = storageClient.create(blobInfo, dataToSave);
return SaveResult.create(blob, nextVersion);
} catch (StorageException e) {
throw new VerifyException("Failed to save artifact to GCS", e);
}
}));
}
}