diff --git a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala
index 63c09f4c30b..d01e820259d 100644
--- a/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala
+++ b/common/workflow-core/src/main/scala/org/apache/texera/amber/core/storage/util/LakeFSStorageClient.scala
@@ -25,6 +25,7 @@ import io.lakefs.clients.sdk.model._
 import org.apache.texera.amber.config.StorageConfig
 
 import java.io.{File, FileOutputStream, InputStream}
+import java.net.URI
 import java.nio.file.Files
 
 import scala.jdk.CollectionConverters._
@@ -358,4 +359,47 @@ object LakeFSStorageClient {
     branchesApi.resetBranch(repoName, branchName, resetCreation).execute()
   }
+
+  /**
+   * Parse a physical address URI of the form "<scheme>://<bucket>/<key>" into (bucket, key).
+   *
+   * Expected examples:
+   * - "s3://my-bucket/path/to/file.csv"
+   * - "gs://my-bucket/some/prefix/data.json"
+   *
+   * @param address URI string in the form "<scheme>://<bucket>/<key>"
+   * @return (bucket, key) where key does not start with "/"
+   * @throws IllegalArgumentException
+   *   if the address is empty, not a valid URI, missing bucket/host, or missing key/path
+   */
+  def parsePhysicalAddress(address: String): (String, String) = {
+    val raw = Option(address).getOrElse("").trim
+    if (raw.isEmpty)
+      throw new IllegalArgumentException("Address is empty (expected '<scheme>://<bucket>/<key>')")
+
+    val uri =
+      try new URI(raw)
+      catch {
+        case e: Exception =>
+          throw new IllegalArgumentException(
+            s"Invalid address URI: '$raw' (expected '<scheme>://<bucket>/<key>')",
+            e
+          )
+      }
+
+    val bucket = Option(uri.getHost).getOrElse("").trim
+    if (bucket.isEmpty)
+      throw new IllegalArgumentException(
+        s"Invalid address: missing host/bucket in '$raw' (expected '<scheme>://<bucket>/<key>')"
+      )
+
+    val key = Option(uri.getPath).getOrElse("").stripPrefix("/").trim
+    if (key.isEmpty)
+      throw new IllegalArgumentException(
+        s"Invalid address: missing key/path in '$raw' (expected '<scheme>://<bucket>/<key>')"
+      )
+
+    (bucket, key)
+  }
+
 }
diff --git a/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala b/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala
index 94007e988e5..b7a66a1bc89 100644
--- a/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala
+++ b/common/workflow-core/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala
@@ -259,4 +259,59 @@ object S3StorageClient {
       DeleteObjectRequest.builder().bucket(bucketName).key(objectKey).build()
     )
   }
+
+  /**
+   * Uploads a single part for an in-progress S3 multipart upload.
+   *
+   * This method wraps the AWS SDK v2 {@code UploadPart} API:
+   * it builds an {@link software.amazon.awssdk.services.s3.model.UploadPartRequest}
+   * and streams the part payload via a {@link software.amazon.awssdk.core.sync.RequestBody}.
+   *
+   * Payload handling:
+   * - If {@code contentLength} is provided, the payload is streamed directly from {@code inputStream}
+   *   using {@code RequestBody.fromInputStream(inputStream, len)}.
+   * - If {@code contentLength} is {@code None}, the entire {@code inputStream} is read into memory
+   *   ({@code readAllBytes}) and uploaded using {@code RequestBody.fromBytes(bytes)}.
+   *   This is convenient but can be memory-expensive for large parts; prefer providing a known length.
+   *
+   * Notes:
+   * - {@code partNumber} must be in the valid S3 range (typically 1..10,000).
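+   * - Illustrative call (a sketch, not taken from this change): for a known 8 MiB part,
+   *   {@code uploadPartWithRequest(bucket, key, uploadId, 1, in, Some(8L * 1024 * 1024))}
+   *   streams the payload directly instead of buffering it via {@code readAllBytes()}.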
+ * - The caller is responsible for closing {@code inputStream}. + * - This method is synchronous and will block the calling thread until the upload completes. + * + * @param bucket S3 bucket name. + * @param key Object key (path) being uploaded. + * @param uploadId Multipart upload identifier returned by CreateMultipartUpload. + * @param partNumber 1-based part number for this upload. + * @param inputStream Stream containing the bytes for this part. + * @param contentLength Optional size (in bytes) of this part; provide it to avoid buffering in memory. + * @return The {@link software.amazon.awssdk.services.s3.model.UploadPartResponse}, + * including the part ETag used for completing the multipart upload. + */ + def uploadPartWithRequest( + bucket: String, + key: String, + uploadId: String, + partNumber: Int, + inputStream: InputStream, + contentLength: Option[Long] + ): UploadPartResponse = { + val requestBody: RequestBody = contentLength match { + case Some(len) => RequestBody.fromInputStream(inputStream, len) + case None => + val bytes = inputStream.readAllBytes() + RequestBody.fromBytes(bytes) + } + + val req = UploadPartRequest + .builder() + .bucket(bucket) + .key(key) + .uploadId(uploadId) + .partNumber(partNumber) + .build() + + s3Client.uploadPart(req, requestBody) + } + } diff --git a/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala b/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala index 023c4ffc88e..44ce22dfb1d 100644 --- a/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala +++ b/file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala @@ -54,7 +54,8 @@ import org.apache.texera.service.util.S3StorageClient.{ MINIMUM_NUM_OF_MULTIPART_S3_PART } import org.jooq.{DSLContext, EnumType} - +import org.jooq.impl.DSL +import org.jooq.impl.DSL.{inline => inl} import java.io.{InputStream, OutputStream} import java.net.{HttpURLConnection, URL, URLDecoder} import java.nio.charset.StandardCharsets @@ -65,6 +66,13 @@ import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ import scala.jdk.OptionConverters._ +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSession.DATASET_UPLOAD_SESSION +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSessionPart.DATASET_UPLOAD_SESSION_PART +import org.jooq.exception.DataAccessException +import software.amazon.awssdk.services.s3.model.UploadPartResponse + +import java.sql.SQLException +import scala.util.Try object DatasetResource { @@ -89,11 +97,11 @@ object DatasetResource { */ private def put(buf: Array[Byte], len: Int, url: String, partNum: Int): String = { val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] - conn.setDoOutput(true); + conn.setDoOutput(true) conn.setRequestMethod("PUT") conn.setFixedLengthStreamingMode(len) val out = conn.getOutputStream - out.write(buf, 0, len); + out.write(buf, 0, len) out.close() val code = conn.getResponseCode @@ -401,7 +409,6 @@ class DatasetResource { e ) } - // delete the directory on S3 if ( S3StorageClient.directoryExists(StorageConfig.lakefsBucketName, dataset.getRepositoryName) @@ -639,138 +646,173 @@ class DatasetResource { @QueryParam("type") operationType: String, @QueryParam("ownerEmail") ownerEmail: String, @QueryParam("datasetName") datasetName: String, - @QueryParam("filePath") encodedUrl: String, - @QueryParam("uploadId") uploadId: 
Optional[String], + @QueryParam("filePath") filePath: String, @QueryParam("numParts") numParts: Optional[Integer], - payload: Map[ - String, - Any - ], // Expecting {"parts": [...], "physicalAddress": "s3://bucket/path"} @Auth user: SessionUser ): Response = { val uid = user.getUid + val dataset: Dataset = getDatasetBy(ownerEmail, datasetName) + + operationType.toLowerCase match { + case "init" => initMultipartUpload(dataset.getDid, filePath, numParts, uid) + case "finish" => finishMultipartUpload(dataset.getDid, filePath, uid) + case "abort" => abortMultipartUpload(dataset.getDid, filePath, uid) + case _ => + throw new BadRequestException("Invalid type parameter. Use 'init', 'finish', or 'abort'.") + } + } - withTransaction(context) { ctx => - val dataset = context - .select(DATASET.fields: _*) - .from(DATASET) - .leftJoin(USER) - .on(USER.UID.eq(DATASET.OWNER_UID)) - .where(USER.EMAIL.eq(ownerEmail)) - .and(DATASET.NAME.eq(datasetName)) - .fetchOneInto(classOf[Dataset]) - if (dataset == null || !userHasWriteAccess(ctx, dataset.getDid, uid)) { - throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) - } + @POST + @RolesAllowed(Array("REGULAR", "ADMIN")) + @Consumes(Array(MediaType.APPLICATION_OCTET_STREAM)) + @Path("/multipart-upload/part") + def uploadPart( + @QueryParam("ownerEmail") datasetOwnerEmail: String, + @QueryParam("datasetName") datasetName: String, + @QueryParam("filePath") encodedFilePath: String, + @QueryParam("partNumber") partNumber: Int, + partStream: InputStream, + @Context headers: HttpHeaders, + @Auth user: SessionUser + ): Response = { - // Decode the file path - val repositoryName = dataset.getRepositoryName - val filePath = URLDecoder.decode(encodedUrl, StandardCharsets.UTF_8.name()) + val uid = user.getUid + val dataset: Dataset = getDatasetBy(datasetOwnerEmail, datasetName) + val did = dataset.getDid - operationType.toLowerCase match { - case "init" => - val numPartsValue = numParts.toScala.getOrElse( - throw new BadRequestException("numParts is required for initialization") - ) + if (encodedFilePath == null || encodedFilePath.isEmpty) + throw new BadRequestException("filePath is required") + if (partNumber < 1) + throw new BadRequestException("partNumber must be >= 1") - val presignedResponse = LakeFSStorageClient.initiatePresignedMultipartUploads( - repositoryName, - filePath, - numPartsValue - ) - Response - .ok( - Map( - "uploadId" -> presignedResponse.getUploadId, - "presignedUrls" -> presignedResponse.getPresignedUrls, - "physicalAddress" -> presignedResponse.getPhysicalAddress - ) - ) - .build() + val filePath = validateAndNormalizeFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) - case "finish" => - val uploadIdValue = uploadId.toScala.getOrElse( - throw new BadRequestException("uploadId is required for completion") - ) + val contentLength = + Option(headers.getHeaderString(HttpHeaders.CONTENT_LENGTH)) + .map(_.trim) + .flatMap(s => Try(s.toLong).toOption) + .filter(_ > 0) + .getOrElse { + throw new BadRequestException("Invalid/Missing Content-Length") + } - // Extract parts from the payload - val partsList = payload.get("parts") match { - case Some(rawList: List[_]) => - try { - rawList.map { - case part: Map[_, _] => - val partMap = part.asInstanceOf[Map[String, Any]] - val partNumber = partMap.get("PartNumber") match { - case Some(i: Int) => i - case Some(s: String) => s.toInt - case _ => throw new BadRequestException("Invalid or missing PartNumber") - } - val eTag = partMap.get("ETag") match { - 
case Some(s: String) => s - case _ => throw new BadRequestException("Invalid or missing ETag") - } - (partNumber, eTag) - - case _ => - throw new BadRequestException("Each part must be a Map[String, Any]") - } - } catch { - case e: NumberFormatException => - throw new BadRequestException("PartNumber must be an integer", e) - } - - case _ => - throw new BadRequestException("Missing or invalid 'parts' list in payload") - } + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) - // Extract physical address from payload - val physicalAddress = payload.get("physicalAddress") match { - case Some(address: String) => address - case _ => throw new BadRequestException("Missing physicalAddress in payload") - } + val session = ctx + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .fetchOne() - // Complete the multipart upload with parts and physical address - val objectStats = LakeFSStorageClient.completePresignedMultipartUploads( - repositoryName, - filePath, - uploadIdValue, - partsList, - physicalAddress - ) + if (session == null) + throw new NotFoundException("Upload session not found. Call type=init first.") - Response - .ok( - Map( - "message" -> "Multipart upload completed successfully", - "filePath" -> objectStats.getPath - ) + val expectedParts = session.getNumPartsRequested + if (partNumber > expectedParts) { + throw new BadRequestException( + s"$partNumber exceeds the requested parts on init: $expectedParts" + ) + } + + if (partNumber < expectedParts && contentLength < MINIMUM_NUM_OF_MULTIPART_S3_PART) { + throw new BadRequestException( + s"Part $partNumber is too small ($contentLength bytes). " + + s"All non-final parts must be >= $MINIMUM_NUM_OF_MULTIPART_S3_PART bytes." + ) + } + + val physicalAddr = Option(session.getPhysicalAddress).map(_.trim).getOrElse("") + if (physicalAddr.isEmpty) { + throw new WebApplicationException( + "Upload session is missing physicalAddress. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + val uploadId = session.getUploadId + val (bucket, key) = + try LakeFSStorageClient.parsePhysicalAddress(physicalAddr) + catch { + case e: IllegalArgumentException => + throw new WebApplicationException( + s"Upload session has invalid physicalAddress. Re-init the upload. (${e.getMessage})", + Response.Status.INTERNAL_SERVER_ERROR ) - .build() + } - case "abort" => - val uploadIdValue = uploadId.toScala.getOrElse( - throw new BadRequestException("uploadId is required for abortion") - ) + // Per-part lock: if another request is streaming the same part, fail fast. 
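+      // (Mechanism sketch, assuming PostgreSQL: SELECT ... FOR UPDATE NOWAIT fails immediately
+      // with SQLSTATE 55P03 ("lock_not_available") when another transaction already holds the
+      // row lock; the catch below maps that failure to HTTP 409 CONFLICT instead of letting the
+      // second uploader block on the lock.)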
+ val partRow = + try { + ctx + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(partNumber)) + ) + .forUpdate() + .noWait() + .fetchOne() + } catch { + case e: DataAccessException + if Option(e.getCause) + .collect { case s: SQLException => s.getSQLState } + .contains("55P03") => + throw new WebApplicationException( + s"Part $partNumber is already being uploaded", + Response.Status.CONFLICT + ) + } - // Extract physical address from payload - val physicalAddress = payload.get("physicalAddress") match { - case Some(address: String) => address - case _ => throw new BadRequestException("Missing physicalAddress in payload") - } + if (partRow == null) { + // Should not happen if init pre-created rows + throw new WebApplicationException( + s"Part row not initialized for part $partNumber. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } - // Abort the multipart upload - LakeFSStorageClient.abortPresignedMultipartUploads( - repositoryName, - filePath, - uploadIdValue, - physicalAddress + // Idempotency: if ETag already set, accept the retry quickly. + val existing = Option(partRow.getEtag).map(_.trim).getOrElse("") + if (existing.isEmpty) { + // Stream to S3 while holding the part lock (prevents concurrent streams for same part) + val response: UploadPartResponse = + S3StorageClient.uploadPartWithRequest( + bucket = bucket, + key = key, + uploadId = uploadId, + partNumber = partNumber, + inputStream = partStream, + contentLength = Some(contentLength) ) - Response.ok(Map("message" -> "Multipart upload aborted successfully")).build() + val etagClean = Option(response.eTag()).map(_.replace("\"", "")).map(_.trim).getOrElse("") + if (etagClean.isEmpty) { + throw new WebApplicationException( + s"Missing ETag returned from S3 for part $partNumber", + Response.Status.INTERNAL_SERVER_ERROR + ) + } - case _ => - throw new BadRequestException("Invalid type parameter. Use 'init', 'finish', or 'abort'.") + ctx + .update(DATASET_UPLOAD_SESSION_PART) + .set(DATASET_UPLOAD_SESSION_PART.ETAG, etagClean) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(partNumber)) + ) + .execute() } + Response.ok().build() } } @@ -1014,9 +1056,8 @@ class DatasetResource { val ownerNode = DatasetFileNode .fromLakeFSRepositoryCommittedObjects( Map( - (user.getEmail, dataset.getName, latestVersion.getName) -> - LakeFSStorageClient - .retrieveObjectsOfVersion(dataset.getRepositoryName, latestVersion.getVersionHash) + (user.getEmail, dataset.getName, latestVersion.getName) -> LakeFSStorageClient + .retrieveObjectsOfVersion(dataset.getRepositoryName, latestVersion.getVersionHash) ) ) .head @@ -1326,4 +1367,379 @@ class DatasetResource { Right(response) } } + + // === Multipart helpers === + + private def getDatasetBy(ownerEmail: String, datasetName: String) = { + val dataset = context + .select(DATASET.fields: _*) + .from(DATASET) + .leftJoin(USER) + .on(USER.UID.eq(DATASET.OWNER_UID)) + .where(USER.EMAIL.eq(ownerEmail)) + .and(DATASET.NAME.eq(datasetName)) + .fetchOneInto(classOf[Dataset]) + if (dataset == null) { + throw new BadRequestException("Dataset not found") + } + dataset + } + + private def validateAndNormalizeFilePathOrThrow(filePath: String): String = { + val path = Option(filePath).getOrElse("").replace("\\", "/") + if ( + path.isEmpty || + path.startsWith("/") || + path.split("/").exists(seg => seg == "." 
|| seg == "..") || + path.exists(ch => ch == 0.toChar || ch < 0x20.toChar || ch == 0x7f.toChar) + ) throw new BadRequestException("Invalid filePath") + path + } + + private def initMultipartUpload( + did: Integer, + encodedFilePath: String, + numParts: Optional[Integer], + uid: Integer + ): Response = { + + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val dataset = getDatasetByID(ctx, did) + val repositoryName = dataset.getRepositoryName + + val filePath = + validateAndNormalizeFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) + + val numPartsValue = numParts.toScala.getOrElse { + throw new BadRequestException("numParts is required for initialization") + } + if (numPartsValue < 1 || numPartsValue > MAXIMUM_NUM_OF_MULTIPART_S3_PARTS) { + throw new BadRequestException( + "numParts must be between 1 and " + MAXIMUM_NUM_OF_MULTIPART_S3_PARTS + ) + } + + // Reject if a session already exists + val exists = ctx.fetchExists( + ctx + .selectOne() + .from(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + ) + if (exists) { + throw new WebApplicationException( + "Upload already in progress for this filePath", + Response.Status.CONFLICT + ) + } + + val presign = LakeFSStorageClient.initiatePresignedMultipartUploads( + repositoryName, + filePath, + numPartsValue + ) + + val uploadIdStr = presign.getUploadId + val physicalAddr = presign.getPhysicalAddress + + // If anything fails after this point, abort LakeFS multipart + try { + val rowsInserted = ctx + .insertInto(DATASET_UPLOAD_SESSION) + .set(DATASET_UPLOAD_SESSION.FILE_PATH, filePath) + .set(DATASET_UPLOAD_SESSION.DID, did) + .set(DATASET_UPLOAD_SESSION.UID, uid) + .set(DATASET_UPLOAD_SESSION.UPLOAD_ID, uploadIdStr) + .set(DATASET_UPLOAD_SESSION.PHYSICAL_ADDRESS, physicalAddr) + .set(DATASET_UPLOAD_SESSION.NUM_PARTS_REQUESTED, numPartsValue) + .onDuplicateKeyIgnore() + .execute() + + if (rowsInserted != 1) { + LakeFSStorageClient.abortPresignedMultipartUploads( + repositoryName, + filePath, + uploadIdStr, + physicalAddr + ) + throw new WebApplicationException( + "Upload already in progress for this filePath", + Response.Status.CONFLICT + ) + } + + // Pre-create part rows 1..numPartsValue with empty ETag. + // This makes per-part locking cheap and deterministic. 
+ + val partNumberSeries = DSL.generateSeries(1, numPartsValue).asTable("gs", "pn") + val partNumberField = partNumberSeries.field("pn", classOf[Integer]) + + ctx + .insertInto( + DATASET_UPLOAD_SESSION_PART, + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID, + DATASET_UPLOAD_SESSION_PART.PART_NUMBER, + DATASET_UPLOAD_SESSION_PART.ETAG + ) + .select( + ctx + .select( + inl(uploadIdStr), + partNumberField, + inl("") // placeholder empty etag + ) + .from(partNumberSeries) + ) + .execute() + + Response.ok().build() + } catch { + case e: Exception => + // rollback will remove session + parts rows; we still must abort LakeFS + try { + LakeFSStorageClient.abortPresignedMultipartUploads( + repositoryName, + filePath, + uploadIdStr, + physicalAddr + ) + } catch { case _: Throwable => () } + throw e + } + } + } + + private def finishMultipartUpload( + did: Integer, + encodedFilePath: String, + uid: Int + ): Response = { + + val filePath = validateAndNormalizeFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) + + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val dataset = getDatasetByID(ctx, did) + + // Lock the session so abort/finish don't race each other + val session = + try { + ctx + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .noWait() + .fetchOne() + } catch { + case e: DataAccessException + if Option(e.getCause) + .collect { case s: SQLException => s.getSQLState } + .contains("55P03") => + throw new WebApplicationException( + "Upload is already being finalized/aborted", + Response.Status.CONFLICT + ) + } + + if (session == null) { + throw new NotFoundException("Upload session not found or already finalized") + } + + val uploadId = session.getUploadId + val expectedParts = session.getNumPartsRequested + + val physicalAddr = Option(session.getPhysicalAddress).map(_.trim).getOrElse("") + if (physicalAddr.isEmpty) { + throw new WebApplicationException( + "Upload session is missing physicalAddress. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + val total = DSL.count() + val done = + DSL + .count() + .filterWhere(DATASET_UPLOAD_SESSION_PART.ETAG.ne("")) + .as("done") + + val agg = ctx + .select(total.as("total"), done) + .from(DATASET_UPLOAD_SESSION_PART) + .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId)) + .fetchOne() + + val totalCnt = agg.get("total", classOf[java.lang.Integer]).intValue() + val doneCnt = agg.get("done", classOf[java.lang.Integer]).intValue() + + if (totalCnt != expectedParts) { + throw new WebApplicationException( + s"Part table mismatch: expected $expectedParts rows but found $totalCnt. Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + if (doneCnt != expectedParts) { + val missing = ctx + .select(DATASET_UPLOAD_SESSION_PART.PART_NUMBER) + .from(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.ETAG.eq("")) + ) + .orderBy(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.asc()) + .limit(50) + .fetch(DATASET_UPLOAD_SESSION_PART.PART_NUMBER) + .asScala + .toList + + throw new WebApplicationException( + s"Upload incomplete. 
Some missing ETags for parts are: ${missing.mkString(",")}", + Response.Status.CONFLICT + ) + } + + // Build partsList in order + val partsList: List[(Int, String)] = + ctx + .select(DATASET_UPLOAD_SESSION_PART.PART_NUMBER, DATASET_UPLOAD_SESSION_PART.ETAG) + .from(DATASET_UPLOAD_SESSION_PART) + .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId)) + .orderBy(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.asc()) + .fetch() + .asScala + .map(r => + ( + r.get(DATASET_UPLOAD_SESSION_PART.PART_NUMBER).intValue(), + r.get(DATASET_UPLOAD_SESSION_PART.ETAG) + ) + ) + .toList + + val objectStats = LakeFSStorageClient.completePresignedMultipartUploads( + dataset.getRepositoryName, + filePath, + uploadId, + partsList, + physicalAddr + ) + + // Cleanup: delete the session; parts are removed by ON DELETE CASCADE + ctx + .deleteFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .execute() + + Response + .ok( + Map( + "message" -> "Multipart upload completed successfully", + "filePath" -> objectStats.getPath + ) + ) + .build() + } + } + + private def abortMultipartUpload( + did: Integer, + encodedFilePath: String, + uid: Int + ): Response = { + + val filePath = validateAndNormalizeFilePathOrThrow( + URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()) + ) + + withTransaction(context) { ctx => + if (!userHasWriteAccess(ctx, did, uid)) { + throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE) + } + + val dataset = getDatasetByID(ctx, did) + + val session = + try { + ctx + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .noWait() + .fetchOne() + } catch { + case e: DataAccessException + if Option(e.getCause) + .collect { case s: SQLException => s.getSQLState } + .contains("55P03") => + throw new WebApplicationException( + "Upload is already being finalized/aborted", + Response.Status.CONFLICT + ) + } + + if (session == null) { + throw new NotFoundException("Upload session not found or already finalized") + } + + val physicalAddr = Option(session.getPhysicalAddress).map(_.trim).getOrElse("") + if (physicalAddr.isEmpty) { + throw new WebApplicationException( + "Upload session is missing physicalAddress. 
Re-init the upload.", + Response.Status.INTERNAL_SERVER_ERROR + ) + } + + LakeFSStorageClient.abortPresignedMultipartUploads( + dataset.getRepositoryName, + filePath, + session.getUploadId, + physicalAddr + ) + + // Delete session; parts removed via ON DELETE CASCADE + ctx + .deleteFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(uid) + .and(DATASET_UPLOAD_SESSION.DID.eq(did)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .execute() + + Response.ok(Map("message" -> "Multipart upload aborted successfully")).build() + } + } } diff --git a/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala b/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala index fd1f0b8c903..10c68bd0858 100644 --- a/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala +++ b/file-service/src/test/scala/org/apache/texera/service/MockLakeFS.scala @@ -20,11 +20,18 @@ package org.apache.texera.service import com.dimafeng.testcontainers._ +import io.lakefs.clients.sdk.{ApiClient, RepositoriesApi} import org.apache.texera.amber.config.StorageConfig import org.apache.texera.service.util.S3StorageClient import org.scalatest.{BeforeAndAfterAll, Suite} import org.testcontainers.containers.Network import org.testcontainers.utility.DockerImageName +import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, StaticCredentialsProvider} +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.s3.S3Client +import software.amazon.awssdk.services.s3.S3Configuration + +import java.net.URI /** * Trait to spin up a LakeFS + MinIO + Postgres stack using Testcontainers, @@ -58,9 +65,14 @@ trait MockLakeFS extends ForAllTestContainer with BeforeAndAfterAll { self: Suit s"postgresql://${postgres.username}:${postgres.password}" + s"@${postgres.container.getNetworkAliases.get(0)}:5432/${postgres.databaseName}" + s"?sslmode=disable" + val lakefsUsername = "texera-admin" + + // These are the API credentials created/used during setup. + // In lakeFS, the access key + secret key are used as basic-auth username/password for the API. val lakefsAccessKeyID = "AKIAIOSFOLKFSSAMPLES" val lakefsSecretAccessKey = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY" + val lakefs = GenericContainer( dockerImage = "treeverse/lakefs:1.51", exposedPorts = Seq(8000), @@ -87,11 +99,45 @@ trait MockLakeFS extends ForAllTestContainer with BeforeAndAfterAll { self: Suit def lakefsBaseUrl: String = s"http://${lakefs.host}:${lakefs.mappedPort(8000)}" def minioEndpoint: String = s"http://${minio.host}:${minio.mappedPort(9000)}" + def lakefsApiBasePath: String = s"$lakefsBaseUrl/api/v1" + + // ---- Clients (lazy so they initialize after containers are started) ---- + + lazy val lakefsApiClient: ApiClient = { + val apiClient = new ApiClient() + apiClient.setBasePath(lakefsApiBasePath) + // basic-auth for lakeFS API uses accessKey as username, secretKey as password + apiClient.setUsername(lakefsAccessKeyID) + apiClient.setPassword(lakefsSecretAccessKey) + apiClient + } + + lazy val repositoriesApi: RepositoriesApi = new RepositoriesApi(lakefsApiClient) + + /** + * S3 client instance for testing pointed at MinIO. + * + * Notes: + * - Region can be any value for MinIO, but MUST match what your signing expects. + * so we use that. 
+ * - Path-style is important: http://host:port/bucket/key + */ + lazy val s3Client: S3Client = { + //Temporal credentials for testing purposes only + val creds = AwsBasicCredentials.create("texera_minio", "password") + S3Client + .builder() + .endpointOverride(URI.create(StorageConfig.s3Endpoint)) // set in afterStart() + .region(Region.US_WEST_2) // Required for `.build()`; not important in this test config. + .credentialsProvider(StaticCredentialsProvider.create(creds)) + .serviceConfiguration(S3Configuration.builder().pathStyleAccessEnabled(true).build()) + .build() + } override def afterStart(): Unit = { super.afterStart() - // setup LakeFS + // setup LakeFS (idempotent-ish, but will fail if it truly cannot run) val lakefsSetupResult = lakefs.container.execInContainer( "lakefs", "setup", @@ -103,16 +149,14 @@ trait MockLakeFS extends ForAllTestContainer with BeforeAndAfterAll { self: Suit lakefsSecretAccessKey ) if (lakefsSetupResult.getExitCode != 0) { - throw new RuntimeException( - s"Failed to setup LakeFS: ${lakefsSetupResult.getStderr}" - ) + throw new RuntimeException(s"Failed to setup LakeFS: ${lakefsSetupResult.getStderr}") } // replace storage endpoints in StorageConfig StorageConfig.s3Endpoint = minioEndpoint - StorageConfig.lakefsEndpoint = s"$lakefsBaseUrl/api/v1" + StorageConfig.lakefsEndpoint = lakefsApiBasePath - // create S3 bucket + // create S3 bucket used by lakeFS in tests S3StorageClient.createBucketIfNotExist(StorageConfig.lakefsBucketName) } } diff --git a/file-service/src/test/scala/org/apache/texera/service/resource/DatasetResourceSpec.scala b/file-service/src/test/scala/org/apache/texera/service/resource/DatasetResourceSpec.scala index f526d9d5610..3f72c574861 100644 --- a/file-service/src/test/scala/org/apache/texera/service/resource/DatasetResourceSpec.scala +++ b/file-service/src/test/scala/org/apache/texera/service/resource/DatasetResourceSpec.scala @@ -19,26 +19,66 @@ package org.apache.texera.service.resource -import jakarta.ws.rs.{BadRequestException, ForbiddenException} +import ch.qos.logback.classic.{Level, Logger} +import io.lakefs.clients.sdk.ApiException +import jakarta.ws.rs._ +import jakarta.ws.rs.core.{Cookie, HttpHeaders, MediaType, MultivaluedHashMap, Response} import org.apache.texera.amber.core.storage.util.LakeFSStorageClient import org.apache.texera.auth.SessionUser import org.apache.texera.dao.MockTexeraDB import org.apache.texera.dao.jooq.generated.enums.{PrivilegeEnum, UserRoleEnum} +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSession.DATASET_UPLOAD_SESSION +import org.apache.texera.dao.jooq.generated.tables.DatasetUploadSessionPart.DATASET_UPLOAD_SESSION_PART import org.apache.texera.dao.jooq.generated.tables.daos.{DatasetDao, UserDao} import org.apache.texera.dao.jooq.generated.tables.pojos.{Dataset, User} import org.apache.texera.service.MockLakeFS -import org.scalatest.BeforeAndAfterAll +import org.jooq.SQLDialect +import org.jooq.impl.DSL +import org.scalatest.tagobjects.Slow +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Tag} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import org.slf4j.LoggerFactory + +import java.io.{ByteArrayInputStream, IOException, InputStream} +import java.net.URLEncoder +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} +import java.security.MessageDigest +import java.util.concurrent.CyclicBarrier +import java.util.{Collections, Date, Locale, Optional} +import scala.concurrent.duration._ +import 
scala.concurrent.{Await, ExecutionContext, Future} +import scala.jdk.CollectionConverters._ +import scala.util.Random + +object StressMultipart extends Tag("org.apache.texera.stress.multipart") class DatasetResourceSpec extends AnyFlatSpec with Matchers with MockTexeraDB with MockLakeFS - with BeforeAndAfterAll { + with BeforeAndAfterAll + with BeforeAndAfterEach { + + // ---------- logging (multipart tests can be noisy) ---------- + private var savedLevels: Map[String, Level] = Map.empty + + private def setLoggerLevel(loggerName: String, newLevel: Level): Level = { + val logger = LoggerFactory.getLogger(loggerName).asInstanceOf[Logger] + val prev = logger.getLevel + logger.setLevel(newLevel) + prev + } - private val testUser: User = { + // ---------- execution context (multipart race tests) ---------- + private implicit val ec: ExecutionContext = ExecutionContext.global + + // --------------------------------------------------------------------------- + // Shared fixtures (DatasetResource basic tests) + // --------------------------------------------------------------------------- + private val ownerUser: User = { val user = new User user.setName("test_user") user.setPassword("123") @@ -47,7 +87,7 @@ class DatasetResourceSpec user } - private val testUser2: User = { + private val otherAdminUser: User = { val user = new User user.setName("test_user2") user.setPassword("123") @@ -56,7 +96,17 @@ class DatasetResourceSpec user } - private val testDataset: Dataset = { + // REGULAR user used specifically for multipart "no WRITE access" tests. + private val multipartNoWriteUser: User = { + val user = new User + user.setName("multipart_user2") + user.setPassword("123") + user.setEmail("multipart_user2@test.com") + user.setRole(UserRoleEnum.REGULAR) + user + } + + private val baseDataset: Dataset = { val dataset = new Dataset dataset.setName("test-dataset") dataset.setRepositoryName("test-dataset") @@ -66,29 +116,84 @@ class DatasetResourceSpec dataset } - lazy val datasetDao = new DatasetDao(getDSLContext.configuration()) + // --------------------------------------------------------------------------- + // Multipart fixtures + // --------------------------------------------------------------------------- + private val multipartRepoName: String = + s"multipart-ds-${System.nanoTime()}-${Random.alphanumeric.take(6).mkString.toLowerCase}" + private val multipartDataset: Dataset = { + val dataset = new Dataset + dataset.setName("multipart-ds") + dataset.setRepositoryName(multipartRepoName) + dataset.setIsPublic(true) + dataset.setIsDownloadable(true) + dataset.setDescription("dataset for multipart upload tests") + dataset + } + + // ---------- DAOs / resource ---------- + lazy val datasetDao = new DatasetDao(getDSLContext.configuration()) lazy val datasetResource = new DatasetResource() - lazy val sessionUser = new SessionUser(testUser) - lazy val sessionUser2 = new SessionUser(testUser2) + // ---------- session users ---------- + lazy val sessionUser = new SessionUser(ownerUser) + lazy val sessionUser2 = new SessionUser(otherAdminUser) + // Multipart callers + lazy val multipartOwnerSessionUser = sessionUser + lazy val multipartNoWriteSessionUser = new SessionUser(multipartNoWriteUser) + + // --------------------------------------------------------------------------- + // Lifecycle + // --------------------------------------------------------------------------- override protected def beforeAll(): Unit = { super.beforeAll() // init db initializeDBAndReplaceDSLContext() - // insert test user + // 
insert users val userDao = new UserDao(getDSLContext.configuration()) - userDao.insert(testUser) - userDao.insert(testUser2) + userDao.insert(ownerUser) + userDao.insert(otherAdminUser) + userDao.insert(multipartNoWriteUser) + + // insert datasets (owned by ownerUser) + baseDataset.setOwnerUid(ownerUser.getUid) + multipartDataset.setOwnerUid(ownerUser.getUid) + + datasetDao.insert(baseDataset) + datasetDao.insert(multipartDataset) + + savedLevels = Map( + "org.apache.http.wire" -> setLoggerLevel("org.apache.http.wire", Level.WARN), + "org.apache.http.headers" -> setLoggerLevel("org.apache.http.headers", Level.WARN) + ) + } + + override protected def beforeEach(): Unit = { + super.beforeEach() + + // Multipart repo must exist for presigned multipart init to succeed. + // If it already exists, ignore 409. + try LakeFSStorageClient.initRepo(multipartDataset.getRepositoryName) + catch { + case e: ApiException if e.getCode == 409 => // ok + } + } - // insert test dataset - testDataset.setOwnerUid(testUser.getUid) - datasetDao.insert(testDataset) + override protected def afterAll(): Unit = { + try shutdownDB() + finally { + try savedLevels.foreach { case (name, prev) => setLoggerLevel(name, prev) } finally super + .afterAll() + } } + // =========================================================================== + // DatasetResourceSpec (original basic tests) + // =========================================================================== "createDataset" should "create dataset successfully if user does not have a dataset with the same name" in { val createDatasetRequest = DatasetResource.CreateDatasetRequest( datasetName = "new-dataset", @@ -142,13 +247,11 @@ class DatasetResourceSpec val dashboardDataset = datasetResource.createDataset(createDatasetRequest, sessionUser) - // Verify the DashboardDataset properties - dashboardDataset.ownerEmail shouldEqual testUser.getEmail + dashboardDataset.ownerEmail shouldEqual ownerUser.getEmail dashboardDataset.accessPrivilege shouldEqual PrivilegeEnum.WRITE dashboardDataset.isOwner shouldBe true dashboardDataset.size shouldEqual 0 - // Verify the underlying dataset properties dashboardDataset.dataset.getName shouldEqual "dashboard-dataset-test" dashboardDataset.dataset.getDescription shouldEqual "test for DashboardDataset properties" dashboardDataset.dataset.getIsPublic shouldBe true @@ -156,51 +259,1073 @@ class DatasetResourceSpec } it should "delete dataset successfully if user owns it" in { - // insert a dataset directly into DB val dataset = new Dataset dataset.setName("delete-ds") dataset.setRepositoryName("delete-ds") dataset.setDescription("for delete test") - dataset.setOwnerUid(testUser.getUid) + dataset.setOwnerUid(ownerUser.getUid) dataset.setIsPublic(true) dataset.setIsDownloadable(true) datasetDao.insert(dataset) - // create repo in LakeFS to match dataset LakeFSStorageClient.initRepo(dataset.getRepositoryName) - // delete via DatasetResource val response = datasetResource.deleteDataset(dataset.getDid, sessionUser) - // assert: response OK and DB no longer contains dataset response.getStatus shouldEqual 200 datasetDao.fetchOneByDid(dataset.getDid) shouldBe null } it should "refuse to delete dataset if not owned by user" in { - // insert a dataset directly into DB val dataset = new Dataset dataset.setName("user1-ds") dataset.setRepositoryName("user1-ds") dataset.setDescription("for forbidden test") - dataset.setOwnerUid(testUser.getUid) + dataset.setOwnerUid(ownerUser.getUid) dataset.setIsPublic(true) dataset.setIsDownloadable(true) 
datasetDao.insert(dataset) - // create repo in LakeFS to match dataset LakeFSStorageClient.initRepo(dataset.getRepositoryName) - // user2 tries to delete, should throw ForbiddenException assertThrows[ForbiddenException] { datasetResource.deleteDataset(dataset.getDid, sessionUser2) } - // dataset still exists in DB datasetDao.fetchOneByDid(dataset.getDid) should not be null } - override protected def afterAll(): Unit = { - shutdownDB() + // =========================================================================== + // Multipart upload tests (merged in) + // =========================================================================== + + // ---------- SHA-256 Utils ---------- + private def sha256OfChunks(chunks: Seq[Array[Byte]]): Array[Byte] = { + val messageDigest = MessageDigest.getInstance("SHA-256") + chunks.foreach(messageDigest.update) + messageDigest.digest() + } + + private def sha256OfFile(path: java.nio.file.Path): Array[Byte] = { + val messageDigest = MessageDigest.getInstance("SHA-256") + val inputStream = Files.newInputStream(path) + try { + val buffer = new Array[Byte](8192) + var bytesRead = inputStream.read(buffer) + while (bytesRead != -1) { + messageDigest.update(buffer, 0, bytesRead) + bytesRead = inputStream.read(buffer) + } + messageDigest.digest() + } finally inputStream.close() + } + + // ---------- helpers ---------- + private def urlEnc(raw: String): String = + URLEncoder.encode(raw, StandardCharsets.UTF_8.name()) + + /** Minimum part-size rule (S3-style): every part except the LAST must be >= 5 MiB. */ + private val MinNonFinalPartBytes: Int = 5 * 1024 * 1024 + private def minPartBytes(fillByte: Byte): Array[Byte] = + Array.fill[Byte](MinNonFinalPartBytes)(fillByte) + + private def tinyBytes(fillByte: Byte, n: Int = 1): Array[Byte] = + Array.fill[Byte](n)(fillByte) + + /** InputStream that behaves like a mid-flight network drop after N bytes. 
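+   * Reads succeed until failAfterBytes bytes have been returned, after which read() throws an
+   * IOException, so tests can exercise uploadPart's failure path without a real network fault.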
*/ + private def flakyStream( + payload: Array[Byte], + failAfterBytes: Int, + msg: String = "simulated network drop" + ): InputStream = + new InputStream { + private var pos = 0 + override def read(): Int = { + if (pos >= failAfterBytes) throw new IOException(msg) + if (pos >= payload.length) return -1 + val nextByte = payload(pos) & 0xff + pos += 1 + nextByte + } + } + + /** Minimal HttpHeaders impl needed by DatasetResource.uploadPart */ + private def mkHeaders(contentLength: Long): HttpHeaders = + new HttpHeaders { + private val headers = new MultivaluedHashMap[String, String]() + headers.putSingle(HttpHeaders.CONTENT_LENGTH, contentLength.toString) + + override def getHeaderString(name: String): String = headers.getFirst(name) + override def getRequestHeaders = headers + override def getRequestHeader(name: String) = + Option(headers.get(name)).getOrElse(Collections.emptyList[String]()) + + override def getAcceptableMediaTypes = Collections.emptyList[MediaType]() + override def getAcceptableLanguages = Collections.emptyList[Locale]() + override def getMediaType: MediaType = null + override def getLanguage: Locale = null + override def getCookies = Collections.emptyMap[String, Cookie]() + override def getDate: Date = null + override def getLength: Int = contentLength.toInt + } + + private def mkHeadersMissingContentLength: HttpHeaders = + new HttpHeaders { + private val headers = new MultivaluedHashMap[String, String]() + override def getHeaderString(name: String): String = null + override def getRequestHeaders = headers + override def getRequestHeader(name: String) = Collections.emptyList[String]() + override def getAcceptableMediaTypes = Collections.emptyList[MediaType]() + override def getAcceptableLanguages = Collections.emptyList[Locale]() + override def getMediaType: MediaType = null + override def getLanguage: Locale = null + override def getCookies = Collections.emptyMap[String, Cookie]() + override def getDate: Date = null + override def getLength: Int = -1 + } + + private def uniqueFilePath(prefix: String): String = + s"$prefix/${System.nanoTime()}-${Random.alphanumeric.take(8).mkString}.bin" + + private def initUpload( + filePath: String, + numParts: Int, + user: SessionUser = multipartOwnerSessionUser + ): Response = + datasetResource.multipartUpload( + "init", + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + Optional.of(numParts), + user + ) + + private def finishUpload( + filePath: String, + user: SessionUser = multipartOwnerSessionUser + ): Response = + datasetResource.multipartUpload( + "finish", + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + Optional.empty(), + user + ) + + private def abortUpload( + filePath: String, + user: SessionUser = multipartOwnerSessionUser + ): Response = + datasetResource.multipartUpload( + "abort", + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + Optional.empty(), + user + ) + + private def uploadPart( + filePath: String, + partNumber: Int, + bytes: Array[Byte], + user: SessionUser = multipartOwnerSessionUser, + contentLengthOverride: Option[Long] = None, + missingContentLength: Boolean = false + ): Response = { + val hdrs = + if (missingContentLength) mkHeadersMissingContentLength + else mkHeaders(contentLengthOverride.getOrElse(bytes.length.toLong)) + + datasetResource.uploadPart( + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + partNumber, + new ByteArrayInputStream(bytes), + hdrs, + user + ) + } + + private def uploadPartWithStream( + filePath: 
String, + partNumber: Int, + stream: InputStream, + contentLength: Long, + user: SessionUser = multipartOwnerSessionUser + ): Response = + datasetResource.uploadPart( + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + partNumber, + stream, + mkHeaders(contentLength), + user + ) + + private def fetchSession(filePath: String) = + getDSLContext + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(ownerUser.getUid) + .and(DATASET_UPLOAD_SESSION.DID.eq(multipartDataset.getDid)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .fetchOne() + + private def fetchPartRows(uploadId: String) = + getDSLContext + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId)) + .fetch() + .asScala + .toList + + private def fetchUploadIdOrFail(filePath: String): String = { + val sessionRecord = fetchSession(filePath) + sessionRecord should not be null + sessionRecord.getUploadId + } + + private def assertPlaceholdersCreated(uploadId: String, expectedParts: Int): Unit = { + val rows = fetchPartRows(uploadId).sortBy(_.getPartNumber) + rows.size shouldEqual expectedParts + rows.head.getPartNumber shouldEqual 1 + rows.last.getPartNumber shouldEqual expectedParts + rows.foreach { r => + r.getEtag should not be null + r.getEtag shouldEqual "" // placeholder convention + } + } + + private def assertStatus(ex: WebApplicationException, status: Int): Unit = + ex.getResponse.getStatus shouldEqual status + + // --------------------------------------------------------------------------- + // INIT TESTS + // --------------------------------------------------------------------------- + "multipart-upload?type=init" should "create an upload session row + precreate part placeholders (happy path)" in { + val filePath = uniqueFilePath("init-happy") + val resp = initUpload(filePath, numParts = 3) + + resp.getStatus shouldEqual 200 + + val sessionRecord = fetchSession(filePath) + sessionRecord should not be null + sessionRecord.getNumPartsRequested shouldEqual 3 + sessionRecord.getUploadId should not be null + sessionRecord.getPhysicalAddress should not be null + + assertPlaceholdersCreated(sessionRecord.getUploadId, expectedParts = 3) + } + + it should "reject missing numParts" in { + val filePath = uniqueFilePath("init-missing-numparts") + val ex = intercept[BadRequestException] { + datasetResource.multipartUpload( + "init", + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + Optional.empty(), + multipartOwnerSessionUser + ) + } + assertStatus(ex, 400) + } + + it should "reject invalid numParts (0, negative, too large)" in { + val filePath = uniqueFilePath("init-bad-numparts") + assertStatus(intercept[BadRequestException] { initUpload(filePath, 0) }, 400) + assertStatus(intercept[BadRequestException] { initUpload(filePath, -1) }, 400) + assertStatus(intercept[BadRequestException] { initUpload(filePath, 1000000000) }, 400) + } + + it should "reject invalid filePath (empty, absolute, '.', '..', control chars)" in { + assertStatus(intercept[BadRequestException] { initUpload("./nope.bin", 2) }, 400) + assertStatus(intercept[BadRequestException] { initUpload("/absolute.bin", 2) }, 400) + assertStatus(intercept[BadRequestException] { initUpload("a/./b.bin", 2) }, 400) + + assertStatus(intercept[BadRequestException] { initUpload("../escape.bin", 2) }, 400) + assertStatus(intercept[BadRequestException] { initUpload("a/../escape.bin", 2) }, 400) + + assertStatus( + intercept[BadRequestException] { + 
initUpload(s"a/${0.toChar}b.bin", 2) + }, + 400 + ) + } + + it should "reject invalid type parameter" in { + val filePath = uniqueFilePath("init-bad-type") + val ex = intercept[BadRequestException] { + datasetResource.multipartUpload( + "not-a-real-type", + ownerUser.getEmail, + multipartDataset.getName, + urlEnc(filePath), + Optional.empty(), + multipartOwnerSessionUser + ) + } + assertStatus(ex, 400) + } + + it should "reject init when caller lacks WRITE access" in { + val filePath = uniqueFilePath("init-forbidden") + val ex = intercept[ForbiddenException] { + initUpload(filePath, numParts = 2, user = multipartNoWriteSessionUser) + } + assertStatus(ex, 403) + } + + it should "handle init race: exactly one succeeds, one gets 409 CONFLICT" in { + val filePath = uniqueFilePath("init-race") + val barrier = new CyclicBarrier(2) + + def callInit(): Either[Throwable, Response] = + try { + barrier.await() + Right(initUpload(filePath, numParts = 2)) + } catch { + case t: Throwable => Left(t) + } + + val future1 = Future(callInit()) + val future2 = Future(callInit()) + val results = Await.result(Future.sequence(Seq(future1, future2)), 30.seconds) + + val oks = results.collect { case Right(r) if r.getStatus == 200 => r } + val fails = results.collect { case Left(t) => t } + + oks.size shouldEqual 1 + fails.size shouldEqual 1 + + fails.head match { + case e: WebApplicationException => assertStatus(e, 409) + case other => + fail( + s"Expected WebApplicationException(CONFLICT), got: ${other.getClass} / ${other.getMessage}" + ) + } + + val sessionRecord = fetchSession(filePath) + sessionRecord should not be null + assertPlaceholdersCreated(sessionRecord.getUploadId, expectedParts = 2) + } + + it should "reject sequential double init with 409 CONFLICT" in { + val filePath = uniqueFilePath("init-double") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + val ex = intercept[WebApplicationException] { initUpload(filePath, numParts = 2) } + assertStatus(ex, 409) + } + + // --------------------------------------------------------------------------- + // PART UPLOAD TESTS + // --------------------------------------------------------------------------- + "multipart-upload/part" should "reject uploadPart if init was not called" in { + val filePath = uniqueFilePath("part-no-init") + val ex = intercept[NotFoundException] { + uploadPart(filePath, partNumber = 1, bytes = Array[Byte](1, 2, 3)) + } + assertStatus(ex, 404) + } + + it should "reject missing/invalid Content-Length" in { + val filePath = uniqueFilePath("part-bad-cl") + initUpload(filePath, numParts = 2) + + assertStatus( + intercept[BadRequestException] { + uploadPart( + filePath, + partNumber = 1, + bytes = Array[Byte](1, 2, 3), + missingContentLength = true + ) + }, + 400 + ) + + assertStatus( + intercept[BadRequestException] { + uploadPart( + filePath, + partNumber = 1, + bytes = Array[Byte](1, 2, 3), + contentLengthOverride = Some(0L) + ) + }, + 400 + ) + + assertStatus( + intercept[BadRequestException] { + uploadPart( + filePath, + partNumber = 1, + bytes = Array[Byte](1, 2, 3), + contentLengthOverride = Some(-5L) + ) + }, + 400 + ) + } + + it should "reject null/empty filePath param early without depending on error text" in { + val httpHeaders = mkHeaders(1L) + + val ex1 = intercept[BadRequestException] { + datasetResource.uploadPart( + ownerUser.getEmail, + multipartDataset.getName, + null, + 1, + new ByteArrayInputStream(Array.emptyByteArray), + httpHeaders, + multipartOwnerSessionUser + ) + } + assertStatus(ex1, 400) + + val ex2 
= intercept[BadRequestException] { + datasetResource.uploadPart( + ownerUser.getEmail, + multipartDataset.getName, + "", + 1, + new ByteArrayInputStream(Array.emptyByteArray), + httpHeaders, + multipartOwnerSessionUser + ) + } + assertStatus(ex2, 400) + } + + it should "reject invalid partNumber (< 1) and partNumber > requested" in { + val filePath = uniqueFilePath("part-bad-pn") + initUpload(filePath, numParts = 2) + + assertStatus( + intercept[BadRequestException] { + uploadPart(filePath, partNumber = 0, bytes = tinyBytes(1.toByte)) + }, + 400 + ) + + assertStatus( + intercept[BadRequestException] { + uploadPart(filePath, partNumber = 3, bytes = minPartBytes(2.toByte)) + }, + 400 + ) + } + + it should "reject a non-final part smaller than the minimum size (without checking message)" in { + val filePath = uniqueFilePath("part-too-small-nonfinal") + initUpload(filePath, numParts = 2) + + val ex = intercept[BadRequestException] { + uploadPart(filePath, partNumber = 1, bytes = tinyBytes(1.toByte)) + } + assertStatus(ex, 400) + + val uploadId = fetchUploadIdOrFail(filePath) + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + } + + it should "upload a part successfully and persist its ETag into DATASET_UPLOAD_SESSION_PART" in { + val filePath = uniqueFilePath("part-happy-db") + initUpload(filePath, numParts = 2) + + val uploadId = fetchUploadIdOrFail(filePath) + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + + val bytes = minPartBytes(7.toByte) + uploadPart(filePath, partNumber = 1, bytes = bytes).getStatus shouldEqual 200 + + val after = fetchPartRows(uploadId).find(_.getPartNumber == 1).get + after.getEtag should not equal "" + } + + it should "allow retrying the same part sequentially (no duplicates, etag ends non-empty)" in { + val filePath = uniqueFilePath("part-retry") + initUpload(filePath, numParts = 2) + val uploadId = fetchUploadIdOrFail(filePath) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 1, minPartBytes(2.toByte)).getStatus shouldEqual 200 + + val rows = fetchPartRows(uploadId).filter(_.getPartNumber == 1) + rows.size shouldEqual 1 + rows.head.getEtag should not equal "" + } + + it should "apply per-part locking: return 409 if that part row is locked by another uploader" in { + val filePath = uniqueFilePath("part-lock") + initUpload(filePath, numParts = 2) + val uploadId = fetchUploadIdOrFail(filePath) + + val connectionProvider = getDSLContext.configuration().connectionProvider() + val connection = connectionProvider.acquire() + connection.setAutoCommit(false) + + try { + val locking = DSL.using(connection, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1)) + ) + .forUpdate() + .fetchOne() + + val ex = intercept[WebApplicationException] { + uploadPart(filePath, 1, minPartBytes(1.toByte)) + } + assertStatus(ex, 409) + } finally { + connection.rollback() + connectionProvider.release(connection) + } + + uploadPart(filePath, 1, minPartBytes(3.toByte)).getStatus shouldEqual 200 + } + + it should "not block other parts: locking part 1 does not prevent uploading part 2" in { + val filePath = uniqueFilePath("part-lock-other-part") + initUpload(filePath, numParts = 2) + val uploadId = fetchUploadIdOrFail(filePath) + + val connectionProvider = getDSLContext.configuration().connectionProvider() + val connection = 
connectionProvider.acquire() + connection.setAutoCommit(false) + + try { + val locking = DSL.using(connection, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION_PART) + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1)) + ) + .forUpdate() + .fetchOne() + + uploadPart(filePath, 2, tinyBytes(9.toByte)).getStatus shouldEqual 200 + } finally { + connection.rollback() + connectionProvider.release(connection) + } + } + + it should "reject uploadPart when caller lacks WRITE access" in { + val filePath = uniqueFilePath("part-forbidden") + initUpload(filePath, numParts = 2) + + val ex = intercept[ForbiddenException] { + uploadPart(filePath, 1, minPartBytes(1.toByte), user = multipartNoWriteSessionUser) + } + assertStatus(ex, 403) + } + + // --------------------------------------------------------------------------- + // FINISH TESTS + // --------------------------------------------------------------------------- + "multipart-upload?type=finish" should "reject finish if init was not called" in { + val filePath = uniqueFilePath("finish-no-init") + val ex = intercept[NotFoundException] { finishUpload(filePath) } + assertStatus(ex, 404) + } + + it should "reject finish when no parts were uploaded (all placeholders empty) without checking messages" in { + val filePath = uniqueFilePath("finish-no-parts") + initUpload(filePath, numParts = 2) + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 409) + + fetchSession(filePath) should not be null + } + + it should "reject finish when some parts are missing (etag empty treated as missing)" in { + val filePath = uniqueFilePath("finish-missing") + initUpload(filePath, numParts = 3) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 409) + + val uploadId = fetchUploadIdOrFail(filePath) + fetchPartRows(uploadId).find(_.getPartNumber == 2).get.getEtag shouldEqual "" + fetchPartRows(uploadId).find(_.getPartNumber == 3).get.getEtag shouldEqual "" + } + + it should "reject finish when extra part rows exist in DB (bypass endpoint) without checking messages" in { + val filePath = uniqueFilePath("finish-extra-db") + initUpload(filePath, numParts = 2) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(2.toByte)).getStatus shouldEqual 200 + + val sessionRecord = fetchSession(filePath) + val uploadId = sessionRecord.getUploadId + + getDSLContext + .insertInto(DATASET_UPLOAD_SESSION_PART) + .set(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID, uploadId) + .set(DATASET_UPLOAD_SESSION_PART.PART_NUMBER, Integer.valueOf(3)) + .set(DATASET_UPLOAD_SESSION_PART.ETAG, "bogus-etag") + .execute() + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 500) + + fetchSession(filePath) should not be null + fetchPartRows(uploadId).nonEmpty shouldEqual true + } + + it should "finish successfully when all parts have non-empty etags; delete session + part rows" in { + val filePath = uniqueFilePath("finish-happy") + initUpload(filePath, numParts = 3) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, minPartBytes(2.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 3, tinyBytes(3.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + + val resp = finishUpload(filePath) + 
resp.getStatus shouldEqual 200 + + fetchSession(filePath) shouldBe null + fetchPartRows(uploadId) shouldBe empty + } + + it should "be idempotent-ish: second finish should return NotFound after successful finish" in { + val filePath = uniqueFilePath("finish-twice") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + finishUpload(filePath).getStatus shouldEqual 200 + + val ex = intercept[NotFoundException] { finishUpload(filePath) } + assertStatus(ex, 404) + } + + it should "reject finish when caller lacks WRITE access" in { + val filePath = uniqueFilePath("finish-forbidden") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + val ex = intercept[ForbiddenException] { + finishUpload(filePath, user = multipartNoWriteSessionUser) + } + assertStatus(ex, 403) + } + + it should "return 409 CONFLICT if the session row is locked by another finalizer/aborter" in { + val filePath = uniqueFilePath("finish-lock-race") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + val connectionProvider = getDSLContext.configuration().connectionProvider() + val connection = connectionProvider.acquire() + connection.setAutoCommit(false) + + try { + val locking = DSL.using(connection, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(ownerUser.getUid) + .and(DATASET_UPLOAD_SESSION.DID.eq(multipartDataset.getDid)) + .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .fetchOne() + + val ex = intercept[WebApplicationException] { finishUpload(filePath) } + assertStatus(ex, 409) + } finally { + connection.rollback() + connectionProvider.release(connection) + } + } + + // --------------------------------------------------------------------------- + // ABORT TESTS + // --------------------------------------------------------------------------- + "multipart-upload?type=abort" should "reject abort if init was not called" in { + val filePath = uniqueFilePath("abort-no-init") + val ex = intercept[NotFoundException] { abortUpload(filePath) } + assertStatus(ex, 404) + } + + it should "abort successfully; delete session + part rows" in { + val filePath = uniqueFilePath("abort-happy") + initUpload(filePath, numParts = 2) + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + + abortUpload(filePath).getStatus shouldEqual 200 + + fetchSession(filePath) shouldBe null + fetchPartRows(uploadId) shouldBe empty + } + + it should "reject abort when caller lacks WRITE access" in { + val filePath = uniqueFilePath("abort-forbidden") + initUpload(filePath, numParts = 1) + + val ex = intercept[ForbiddenException] { + abortUpload(filePath, user = multipartNoWriteSessionUser) + } + assertStatus(ex, 403) + } + + it should "return 409 CONFLICT if the session row is locked by another finalizer/aborter" in { + val filePath = uniqueFilePath("abort-lock-race") + initUpload(filePath, numParts = 1) + + val connectionProvider = getDSLContext.configuration().connectionProvider() + val connection = connectionProvider.acquire() + connection.setAutoCommit(false) + + try { + val locking = DSL.using(connection, SQLDialect.POSTGRES) + locking + .selectFrom(DATASET_UPLOAD_SESSION) + .where( + DATASET_UPLOAD_SESSION.UID + .eq(ownerUser.getUid) + .and(DATASET_UPLOAD_SESSION.DID.eq(multipartDataset.getDid)) + 
.and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath)) + ) + .forUpdate() + .fetchOne() + + val ex = intercept[WebApplicationException] { abortUpload(filePath) } + assertStatus(ex, 409) + } finally { + connection.rollback() + connectionProvider.release(connection) + } + } + + it should "be consistent: abort after finish should return NotFound" in { + val filePath = uniqueFilePath("abort-after-finish") + initUpload(filePath, numParts = 1) + uploadPart(filePath, 1, tinyBytes(1.toByte)).getStatus shouldEqual 200 + + finishUpload(filePath).getStatus shouldEqual 200 + + val ex = intercept[NotFoundException] { abortUpload(filePath) } + assertStatus(ex, 404) + } + + // --------------------------------------------------------------------------- + // FAILURE / RESILIENCE (still unit tests; simulated failures) + // --------------------------------------------------------------------------- + "multipart upload implementation" should "release locks and keep DB consistent if the incoming stream fails mid-upload (simulated network drop)" in { + val filePath = uniqueFilePath("netfail-upload-stream") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + val uploadId = fetchUploadIdOrFail(filePath) + + val payload = minPartBytes(5.toByte) + + val flaky = new InputStream { + private var pos = 0 + override def read(): Int = { + if (pos >= 1024) throw new IOException("simulated network drop") + val b = payload(pos) & 0xff + pos += 1 + b + } + } + + intercept[Throwable] { + uploadPartWithStream( + filePath, + partNumber = 1, + stream = flaky, + contentLength = payload.length.toLong + ) + } + + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + + uploadPart(filePath, 1, payload).getStatus shouldEqual 200 + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag should not equal "" + } + + it should "not delete session/parts if finalize fails downstream (simulate by corrupting an ETag)" in { + val filePath = uniqueFilePath("netfail-finish") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(2.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + + getDSLContext + .update(DATASET_UPLOAD_SESSION_PART) + .set(DATASET_UPLOAD_SESSION_PART.ETAG, "definitely-not-a-real-etag") + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1)) + ) + .execute() + + intercept[Throwable] { finishUpload(filePath) } + + fetchSession(filePath) should not be null + fetchPartRows(uploadId).nonEmpty shouldEqual true + } + + it should "allow abort + re-init after part 1 succeeded but part 2 drops mid-flight; then complete successfully" in { + val filePath = uniqueFilePath("reinit-after-part2-drop") + + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + val uploadId1 = fetchUploadIdOrFail(filePath) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + + val bytesPart2 = Array.fill[Byte](1024 * 1024)(2.toByte) + intercept[Throwable] { + uploadPartWithStream( + filePath, + partNumber = 2, + stream = flakyStream(bytesPart2, failAfterBytes = 4096), + contentLength = bytesPart2.length.toLong + ) + } + + abortUpload(filePath).getStatus shouldEqual 200 + fetchSession(filePath) shouldBe null + fetchPartRows(uploadId1) shouldBe empty + + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + uploadPart(filePath, 1, 
minPartBytes(3.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(4.toByte, n = 123)).getStatus shouldEqual 200 + finishUpload(filePath).getStatus shouldEqual 200 + fetchSession(filePath) shouldBe null + } + + it should "allow re-upload after failures: (1) part1 drop, (2) part2 drop, (3) finalize failure; each followed by abort + re-init + success" in { + def abortAndAssertClean(filePath: String, uploadId: String): Unit = { + abortUpload(filePath).getStatus shouldEqual 200 + fetchSession(filePath) shouldBe null + fetchPartRows(uploadId) shouldBe empty + } + + def reinitAndFinishHappy(filePath: String): Unit = { + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + uploadPart(filePath, 1, minPartBytes(7.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(8.toByte, n = 321)).getStatus shouldEqual 200 + finishUpload(filePath).getStatus shouldEqual 200 + fetchSession(filePath) shouldBe null + } + + withClue("scenario (1): part1 mid-flight drop") { + val filePath = uniqueFilePath("reupload-part1-drop") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + val uploadId = fetchUploadIdOrFail(filePath) + + val p1 = minPartBytes(5.toByte) + intercept[Throwable] { + uploadPartWithStream( + filePath, + partNumber = 1, + stream = flakyStream(p1, failAfterBytes = 4096), + contentLength = p1.length.toLong + ) + } + + fetchPartRows(uploadId).find(_.getPartNumber == 1).get.getEtag shouldEqual "" + + abortAndAssertClean(filePath, uploadId) + reinitAndFinishHappy(filePath) + } + + withClue("scenario (2): part2 mid-flight drop") { + val filePath = uniqueFilePath("reupload-part2-drop") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + val uploadId = fetchUploadIdOrFail(filePath) + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + val bytesPart2 = Array.fill[Byte](1024 * 1024)(2.toByte) + intercept[Throwable] { + uploadPartWithStream( + filePath, + partNumber = 2, + stream = flakyStream(bytesPart2, failAfterBytes = 4096), + contentLength = bytesPart2.length.toLong + ) + } + + abortAndAssertClean(filePath, uploadId) + reinitAndFinishHappy(filePath) + } + + withClue("scenario (3): finalize failure then re-upload") { + val filePath = uniqueFilePath("reupload-finalize-fail") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + uploadPart(filePath, 1, minPartBytes(1.toByte)).getStatus shouldEqual 200 + uploadPart(filePath, 2, tinyBytes(2.toByte)).getStatus shouldEqual 200 + + val uploadId = fetchUploadIdOrFail(filePath) + getDSLContext + .update(DATASET_UPLOAD_SESSION_PART) + .set(DATASET_UPLOAD_SESSION_PART.ETAG, "definitely-not-a-real-etag") + .where( + DATASET_UPLOAD_SESSION_PART.UPLOAD_ID + .eq(uploadId) + .and(DATASET_UPLOAD_SESSION_PART.PART_NUMBER.eq(1)) + ) + .execute() + + intercept[Throwable] { finishUpload(filePath) } + fetchSession(filePath) should not be null + fetchPartRows(uploadId).nonEmpty shouldEqual true + + abortAndAssertClean(filePath, uploadId) + reinitAndFinishHappy(filePath) + } + } + + // --------------------------------------------------------------------------- + // CORRUPTION CHECKS + // --------------------------------------------------------------------------- + it should "upload without corruption (sha256 matches final object)" in { + val filePath = uniqueFilePath("sha256-positive") + initUpload(filePath, numParts = 3).getStatus shouldEqual 200 + + val part1 = minPartBytes(1.toByte) + val part2 = minPartBytes(2.toByte) + val part3 = Array.fill[Byte](123)(3.toByte) + + 
uploadPart(filePath, 1, part1).getStatus shouldEqual 200 + uploadPart(filePath, 2, part2).getStatus shouldEqual 200 + uploadPart(filePath, 3, part3).getStatus shouldEqual 200 + + finishUpload(filePath).getStatus shouldEqual 200 + + val expected = sha256OfChunks(Seq(part1, part2, part3)) + + val repoName = multipartDataset.getRepositoryName + val ref = "main" + val downloaded = LakeFSStorageClient.getFileFromRepo(repoName, ref, filePath) + + val got = sha256OfFile(Paths.get(downloaded.toURI)) + got.toSeq shouldEqual expected.toSeq + } + + it should "detect corruption (sha256 mismatch when a part is altered)" in { + val filePath = uniqueFilePath("sha256-negative") + initUpload(filePath, numParts = 3).getStatus shouldEqual 200 + + val part1 = minPartBytes(1.toByte) + val part2 = minPartBytes(2.toByte) + val part3 = Array.fill[Byte](123)(3.toByte) + + val intendedHash = sha256OfChunks(Seq(part1, part2, part3)) + + val part2corrupt = part2.clone() + part2corrupt(0) = (part2corrupt(0) ^ 0x01).toByte + + uploadPart(filePath, 1, part1).getStatus shouldEqual 200 + uploadPart(filePath, 2, part2corrupt).getStatus shouldEqual 200 + uploadPart(filePath, 3, part3).getStatus shouldEqual 200 + + finishUpload(filePath).getStatus shouldEqual 200 + + val repoName = multipartDataset.getRepositoryName + val ref = "main" + val downloaded = LakeFSStorageClient.getFileFromRepo(repoName, ref, filePath) + + val gotHash = sha256OfFile(Paths.get(downloaded.toURI)) + gotHash.toSeq should not equal intendedHash.toSeq + + val corruptHash = sha256OfChunks(Seq(part1, part2corrupt, part3)) + gotHash.toSeq shouldEqual corruptHash.toSeq + } + + // --------------------------------------------------------------------------- + // STRESS / SOAK TESTS (tagged) + // --------------------------------------------------------------------------- + it should "survive 2 concurrent multipart uploads (fan-out)" taggedAs (StressMultipart, Slow) in { + val parallelUploads = 2 + val maxParts = 2 + + def oneUpload(i: Int): Future[Unit] = + Future { + val filePath = uniqueFilePath(s"stress-$i") + val numParts = 2 + Random.nextInt(maxParts - 1) + + initUpload(filePath, numParts).getStatus shouldEqual 200 + + val sharedMin = minPartBytes((i % 127).toByte) + val partFuts = (1 to numParts).map { partN => + Future { + val bytes = + if (partN < numParts) sharedMin + else tinyBytes((partN % 127).toByte, n = 1024) + uploadPart(filePath, partN, bytes).getStatus shouldEqual 200 + } + } + + Await.result(Future.sequence(partFuts), 60.seconds) + + finishUpload(filePath).getStatus shouldEqual 200 + fetchSession(filePath) shouldBe null + } + + val all = Future.sequence((1 to parallelUploads).map(oneUpload)) + Await.result(all, 180.seconds) + } + + it should "throttle concurrent uploads of the SAME part via per-part locks" taggedAs (StressMultipart, Slow) in { + val filePath = uniqueFilePath("stress-same-part") + initUpload(filePath, numParts = 2).getStatus shouldEqual 200 + + val contenders = 2 + val barrier = new CyclicBarrier(contenders) + + def tryUploadStatus(): Future[Int] = + Future { + barrier.await() + try { + uploadPart(filePath, 1, minPartBytes(7.toByte)).getStatus + } catch { + case e: WebApplicationException => e.getResponse.getStatus + } + } + + val statuses = + Await.result(Future.sequence((1 to contenders).map(_ => tryUploadStatus())), 60.seconds) + + statuses.foreach { s => s should (be(200) or be(409)) } + statuses.count(_ == 200) should be >= 1 + + val uploadId = fetchUploadIdOrFail(filePath) + val part1 = 
fetchPartRows(uploadId).find(_.getPartNumber == 1).get + part1.getEtag.trim should not be "" } } diff --git a/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts b/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts index b4d12f5a28e..bfc97379ecd 100644 --- a/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts +++ b/frontend/src/app/dashboard/component/user/user-dataset/user-dataset-explorer/dataset-detail.component.ts @@ -39,12 +39,14 @@ import { FileUploadItem } from "../../../../type/dashboard-file.interface"; import { DatasetStagedObject } from "../../../../../common/type/dataset-staged-object"; import { NzModalService } from "ng-zorro-antd/modal"; import { AdminSettingsService } from "../../../../service/admin/settings/admin-settings.service"; -import { HttpErrorResponse } from "@angular/common/http"; +import { HttpErrorResponse, HttpStatusCode } from "@angular/common/http"; import { Subscription } from "rxjs"; import { formatSpeed, formatTime } from "src/app/common/util/format.util"; import { format } from "date-fns"; export const THROTTLE_TIME_MS = 1000; +export const ABORT_RETRY_MAX_ATTEMPTS = 10; +export const ABORT_RETRY_BACKOFF_BASE_MS = 100; @UntilDestroy() @Component({ @@ -405,103 +407,107 @@ export class DatasetDetailComponent implements OnInit { if (this.did) { files.forEach(file => { // Check if currently uploading - this.cancelExistingUpload(file.name); - - // Create upload function - const startUpload = () => { - this.pendingQueue = this.pendingQueue.filter(item => item.fileName !== file.name); - - // Add an initializing task placeholder to uploadTasks - this.uploadTasks.unshift({ - filePath: file.name, - percentage: 0, - status: "initializing", - uploadId: "", - physicalAddress: "", - }); - // Start multipart upload - const subscription = this.datasetService - .multipartUpload( - this.ownerEmail, - this.datasetName, - file.name, - file.file, - this.chunkSizeMiB * 1024 * 1024, - this.maxConcurrentChunks - ) - .pipe(untilDestroyed(this)) - .subscribe({ - next: progress => { - // Find the task - const taskIndex = this.uploadTasks.findIndex(t => t.filePath === file.name); - - if (taskIndex !== -1) { - // Update the task with new progress info - this.uploadTasks[taskIndex] = { - ...this.uploadTasks[taskIndex], - ...progress, - percentage: progress.percentage ?? this.uploadTasks[taskIndex].percentage ?? 
0, - }; - - // Auto-hide when upload is truly finished - if (progress.status === "finished" && progress.totalTime) { - const filename = file.name.split("/").pop() || file.name; - this.uploadTimeMap.set(filename, progress.totalTime); + const continueWithUpload = () => { + // Create upload function + const startUpload = () => { + this.pendingQueue = this.pendingQueue.filter(item => item.fileName !== file.name); + + // Add an initializing task placeholder to uploadTasks + this.uploadTasks.unshift({ + filePath: file.name, + percentage: 0, + status: "initializing", + }); + // Start multipart upload + const subscription = this.datasetService + .multipartUpload( + this.ownerEmail, + this.datasetName, + file.name, + file.file, + this.chunkSizeMiB * 1024 * 1024, + this.maxConcurrentChunks + ) + .pipe(untilDestroyed(this)) + .subscribe({ + next: progress => { + // Find the task + const taskIndex = this.uploadTasks.findIndex(t => t.filePath === file.name); + + if (taskIndex !== -1) { + // Update the task with new progress info + this.uploadTasks[taskIndex] = { + ...this.uploadTasks[taskIndex], + ...progress, + percentage: progress.percentage ?? this.uploadTasks[taskIndex].percentage ?? 0, + }; + + // Auto-hide when upload is truly finished + if (progress.status === "finished" && progress.totalTime) { + const filename = file.name.split("/").pop() || file.name; + this.uploadTimeMap.set(filename, progress.totalTime); + this.userMakeChanges.emit(); + this.scheduleHide(taskIndex); + this.onUploadComplete(); + } + } + }, + error: () => { + // Handle upload error + const taskIndex = this.uploadTasks.findIndex(t => t.filePath === file.name); + + if (taskIndex !== -1) { + this.uploadTasks[taskIndex] = { + ...this.uploadTasks[taskIndex], + percentage: 100, + status: "aborted", + }; + this.scheduleHide(taskIndex); + } + this.onUploadComplete(); + }, + complete: () => { + const taskIndex = this.uploadTasks.findIndex(t => t.filePath === file.name); + if (taskIndex !== -1 && this.uploadTasks[taskIndex].status !== "finished") { + this.uploadTasks[taskIndex].status = "finished"; this.userMakeChanges.emit(); this.scheduleHide(taskIndex); this.onUploadComplete(); } - } - }, - error: () => { - // Handle upload error - const taskIndex = this.uploadTasks.findIndex(t => t.filePath === file.name); - - if (taskIndex !== -1) { - this.uploadTasks[taskIndex] = { - ...this.uploadTasks[taskIndex], - percentage: 100, - status: "aborted", - }; - this.scheduleHide(taskIndex); - } - this.onUploadComplete(); - }, - complete: () => { - const taskIndex = this.uploadTasks.findIndex(t => t.filePath === file.name); - if (taskIndex !== -1 && this.uploadTasks[taskIndex].status !== "finished") { - this.uploadTasks[taskIndex].status = "finished"; - this.userMakeChanges.emit(); - this.scheduleHide(taskIndex); - this.onUploadComplete(); - } - }, - }); - // Store the subscription for later cleanup - this.uploadSubscriptions.set(file.name, subscription); + }, + }); + // Store the subscription for later cleanup + this.uploadSubscriptions.set(file.name, subscription); + }; + + // Queue management + if (this.activeUploads < this.maxConcurrentFiles) { + this.activeUploads++; + startUpload(); + } else { + this.pendingQueue.push({ fileName: file.name, startUpload }); + } }; - // Queue management - if (this.activeUploads < this.maxConcurrentFiles) { - this.activeUploads++; - startUpload(); - } else { - this.pendingQueue.push({ fileName: file.name, startUpload }); - } + // Check if currently uploading + this.cancelExistingUpload(file.name, 
continueWithUpload); }); } } - cancelExistingUpload(fileName: string): void { + cancelExistingUpload(fileName: string, onCanceled?: () => void): void { const task = this.uploadTasks.find(t => t.filePath === fileName); if (task) { if (task.status === "uploading" || task.status === "initializing") { - this.onClickAbortUploadProgress(task); + this.onClickAbortUploadProgress(task, onCanceled); return; } } // Remove from pending queue if present this.pendingQueue = this.pendingQueue.filter(item => item.fileName !== fileName); + if (onCanceled) { + onCanceled(); + } } private processNextQueuedUpload(): void { @@ -547,7 +553,7 @@ export class DatasetDetailComponent implements OnInit { }, 5000); } - onClickAbortUploadProgress(task: MultipartUploadProgress & { filePath: string }) { + onClickAbortUploadProgress(task: MultipartUploadProgress & { filePath: string }, onAborted?: () => void) { const subscription = this.uploadSubscriptions.get(task.filePath); if (subscription) { subscription.unsubscribe(); @@ -558,21 +564,54 @@ export class DatasetDetailComponent implements OnInit { this.onUploadComplete(); } - this.datasetService - .finalizeMultipartUpload( - this.ownerEmail, - this.datasetName, - task.filePath, - task.uploadId, - [], - task.physicalAddress, - true // abort flag - ) - .pipe(untilDestroyed(this)) - .subscribe(() => { - this.notificationService.info(`${task.filePath} uploading has been terminated`); - }); - // Remove the aborted task immediately + let doneCalled = false; + const done = () => { + if (doneCalled) { + return; + } + doneCalled = true; + if (onAborted) { + onAborted(); + } + }; + + const abortWithRetry = (attempt: number) => { + this.datasetService + .finalizeMultipartUpload( + this.ownerEmail, + this.datasetName, + task.filePath, + true // abort flag + ) + .pipe(untilDestroyed(this)) + .subscribe({ + next: () => { + this.notificationService.info(`${task.filePath} uploading has been terminated`); + done(); + }, + error: (res: unknown) => { + const err = res as HttpErrorResponse; + + // Already gone, treat as done + if (err.status === 404) { + done(); + return; + } + + // Backend is still finalizing/aborting; retry with a tiny backoff + if (err.status === HttpStatusCode.Conflict && attempt < ABORT_RETRY_MAX_ATTEMPTS) { + setTimeout(() => abortWithRetry(attempt + 1), ABORT_RETRY_BACKOFF_BASE_MS * (attempt + 1)); + return; + } + + // Keep current UX: still consider it "aborted" client-side + done(); + }, + }); + }; + + abortWithRetry(0); + this.uploadTasks = this.uploadTasks.filter(t => t.filePath !== task.filePath); } diff --git a/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts b/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts index c09125d73b1..97b2e264b75 100644 --- a/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts +++ b/frontend/src/app/dashboard/service/user/dataset/dataset.service.ts @@ -18,7 +18,7 @@ */ import { Injectable } from "@angular/core"; -import { HttpClient, HttpParams } from "@angular/common/http"; +import { HttpClient, HttpErrorResponse, HttpParams } from "@angular/common/http"; import { catchError, map, mergeMap, switchMap, tap, toArray } from "rxjs/operators"; import { Dataset, DatasetVersion } from "../../../../common/type/dataset"; import { AppSettings } from "../../../../common/app-setting"; @@ -27,6 +27,7 @@ import { DashboardDataset } from "../../../type/dashboard-dataset.interface"; import { DatasetFileNode } from "../../../../common/type/datasetVersionFileTree"; import { DatasetStagedObject } 
from "../../../../common/type/dataset-staged-object"; import { GuiConfigService } from "../../../../common/service/gui-config.service"; +import { AuthService } from "src/app/common/service/user/auth.service"; export const DATASET_BASE_URL = "dataset"; export const DATASET_CREATE_URL = DATASET_BASE_URL + "/create"; @@ -51,8 +52,6 @@ export interface MultipartUploadProgress { filePath: string; percentage: number; status: "initializing" | "uploading" | "finished" | "aborted"; - uploadId: string; - physicalAddress: string; uploadSpeed?: number; // bytes per second estimatedTimeRemaining?: number; // seconds totalTime?: number; // total seconds taken @@ -122,6 +121,7 @@ export class DatasetService { public retrieveAccessibleDatasets(): Observable { return this.http.get(`${AppSettings.getApiEndpoint()}/${DATASET_LIST_URL}`); } + public createDatasetVersion(did: number, newVersion: string): Observable { return this.http .post<{ @@ -141,6 +141,12 @@ export class DatasetService { /** * Handles multipart upload for large files using RxJS, * with a concurrency limit on how many parts we process in parallel. + * + * Backend flow: + * POST /dataset/multipart-upload?type=init&ownerEmail=...&datasetName=...&filePath=...&numParts=N + * POST /dataset/multipart-upload/part?ownerEmail=...&datasetName=...&filePath=...&partNumber= (body: raw chunk) + * POST /dataset/multipart-upload?type=finish&ownerEmail=...&datasetName=...&filePath=... + * POST /dataset/multipart-upload?type=abort&ownerEmail=...&datasetName=...&filePath=... */ public multipartUpload( ownerEmail: string, @@ -152,8 +158,8 @@ export class DatasetService { ): Observable { const partCount = Math.ceil(file.size / partSize); - return new Observable(observer => { - // Track upload progress for each part independently + return new Observable(observer => { + // Track upload progress (bytes) for each part independently const partProgress = new Map(); // Progress tracking state @@ -162,8 +168,15 @@ export class DatasetService { let lastETA = 0; let lastUpdateTime = 0; - // Calculate stats with smoothing + const lastStats = { + uploadSpeed: 0, + estimatedTimeRemaining: 0, + totalTime: 0, + }; + const getTotalTime = () => (startTime ? (Date.now() - startTime) / 1000 : 0); + + // Calculate stats with smoothing and simple throttling (~1s) const calculateStats = (totalUploaded: number) => { if (startTime === null) { startTime = Date.now(); @@ -172,25 +185,25 @@ export class DatasetService { const now = Date.now(); const elapsed = getTotalTime(); - // Throttle updates to every 1s const shouldUpdate = now - lastUpdateTime >= 1000; if (!shouldUpdate) { - return null; + // keep totalTime fresh even when throttled + lastStats.totalTime = elapsed; + return lastStats; } lastUpdateTime = now; - // Calculate speed with moving average const currentSpeed = elapsed > 0 ? totalUploaded / elapsed : 0; speedSamples.push(currentSpeed); - if (speedSamples.length > 5) speedSamples.shift(); - const avgSpeed = speedSamples.reduce((a, b) => a + b, 0) / speedSamples.length; + if (speedSamples.length > 5) { + speedSamples.shift(); + } + const avgSpeed = speedSamples.length > 0 ? speedSamples.reduce((a, b) => a + b, 0) / speedSamples.length : 0; - // Calculate smooth ETA const remaining = file.size - totalUploaded; let eta = avgSpeed > 0 ? 
remaining / avgSpeed : 0; - eta = Math.min(eta, 24 * 60 * 60); // cap ETA at 24h, 86400 sec + eta = Math.min(eta, 24 * 60 * 60); // cap ETA at 24h - // Smooth ETA changes (limit to 30% change) if (lastETA > 0 && eta > 0) { const maxChange = lastETA * 0.3; const diff = Math.abs(eta - lastETA); @@ -200,106 +213,118 @@ export class DatasetService { } lastETA = eta; - // Near completion optimization const percentComplete = (totalUploaded / file.size) * 100; if (percentComplete > 95) { eta = Math.min(eta, 10); } - return { - uploadSpeed: avgSpeed, - estimatedTimeRemaining: Math.max(0, Math.round(eta)), - totalTime: elapsed, - }; + lastStats.uploadSpeed = avgSpeed; + lastStats.estimatedTimeRemaining = Math.max(0, Math.round(eta)); + lastStats.totalTime = elapsed; + + return lastStats; }; - const subscription = this.initiateMultipartUpload(ownerEmail, datasetName, filePath, partCount) + // 1. INIT: ask backend to create a LakeFS multipart upload session + const initParams = new HttpParams() + .set("type", "init") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)) + .set("numParts", partCount.toString()); + + const init$ = this.http.post<{}>( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: initParams } + ); + + const initWithAbortRetry$ = init$.pipe( + catchError((res: unknown) => { + const err = res as HttpErrorResponse; + if (err.status !== 409) { + return throwError(() => err); + } + + // Init failed because a session already exists. Abort it and retry init once. + return this.finalizeMultipartUpload(ownerEmail, datasetName, filePath, true).pipe( + // best-effort abort; if abort itself fails, let the re-init decide + catchError(() => EMPTY), + switchMap(() => init$) + ); + }) + ); + + const subscription = initWithAbortRetry$ .pipe( - switchMap(initiateResponse => { - const { uploadId, presignedUrls, physicalAddress } = initiateResponse; - if (!uploadId) { - observer.error(new Error("Failed to initiate multipart upload")); - return EMPTY; - } + switchMap(initResp => { + // Notify UI that upload is starting observer.next({ - filePath: filePath, + filePath, percentage: 0, status: "initializing", - uploadId: uploadId, - physicalAddress: physicalAddress, uploadSpeed: 0, estimatedTimeRemaining: 0, totalTime: 0, }); - // Keep track of all uploaded parts - const uploadedParts: { PartNumber: number; ETag: string }[] = []; - - // 1) Convert presignedUrls into a stream of URLs - return from(presignedUrls).pipe( - // 2) Use mergeMap with concurrency limit to upload chunk by chunk - mergeMap((url, index) => { + // 2. 
Upload each part to /multipart-upload/part using XMLHttpRequest + return from(Array.from({ length: partCount }, (_, i) => i)).pipe( + mergeMap(index => { const partNumber = index + 1; const start = index * partSize; const end = Math.min(start + partSize, file.size); const chunk = file.slice(start, end); - // Upload the chunk - return new Observable(partObserver => { + return new Observable(partObserver => { const xhr = new XMLHttpRequest(); xhr.upload.addEventListener("progress", event => { if (event.lengthComputable) { - // Update this specific part's progress partProgress.set(partNumber, event.loaded); - // Calculate total progress across all parts let totalUploaded = 0; - partProgress.forEach(bytes => (totalUploaded += bytes)); + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + const percentage = Math.round((totalUploaded / file.size) * 100); const stats = calculateStats(totalUploaded); observer.next({ filePath, - percentage: Math.min(percentage, 99), // Cap at 99% until finalized + percentage: Math.min(percentage, 99), status: "uploading", - uploadId, - physicalAddress, ...stats, }); } }); xhr.addEventListener("load", () => { - if (xhr.status === 200 || xhr.status === 201) { - const etag = xhr.getResponseHeader("ETag")?.replace(/"/g, ""); - if (!etag) { - partObserver.error(new Error(`Missing ETag for part ${partNumber}`)); - return; - } - - // Mark this part as fully uploaded + if (xhr.status === 200 || xhr.status === 204) { + // Mark part as fully uploaded partProgress.set(partNumber, chunk.size); - uploadedParts.push({ PartNumber: partNumber, ETag: etag }); - // Recalculate progress let totalUploaded = 0; - partProgress.forEach(bytes => (totalUploaded += bytes)); - const percentage = Math.round((totalUploaded / file.size) * 100); + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + + // Force stats recompute on completion lastUpdateTime = 0; + const percentage = Math.round((totalUploaded / file.size) * 100); const stats = calculateStats(totalUploaded); observer.next({ filePath, percentage: Math.min(percentage, 99), status: "uploading", - uploadId, - physicalAddress, ...stats, }); + partObserver.complete(); } else { - partObserver.error(new Error(`Failed to upload part ${partNumber}`)); + partObserver.error(new Error(`Failed to upload part ${partNumber} (HTTP ${xhr.status})`)); } }); @@ -309,60 +334,88 @@ export class DatasetService { partObserver.error(new Error(`Failed to upload part ${partNumber}`)); }); - xhr.open("PUT", url); + const partUrl = + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload/part` + + `?ownerEmail=${encodeURIComponent(ownerEmail)}` + + `&datasetName=${encodeURIComponent(datasetName)}` + + `&filePath=${encodeURIComponent(filePath)}` + + `&partNumber=${partNumber}`; + + xhr.open("POST", partUrl); + xhr.setRequestHeader("Content-Type", "application/octet-stream"); + const token = AuthService.getAccessToken(); + if (token) { + xhr.setRequestHeader("Authorization", `Bearer ${token}`); + } xhr.send(chunk); + return () => { + try { + xhr.abort(); + } catch {} + }; }); }, concurrencyLimit), - - // 3) Collect results from all uploads (like forkJoin, but respects concurrency) - toArray(), - // 4) Finalize if all parts succeeded - switchMap(() => - this.finalizeMultipartUpload( - ownerEmail, - datasetName, - filePath, - uploadId, - uploadedParts, - physicalAddress, - false - ) - ), + toArray(), // wait for all parts + // 3. 
FINISH: notify backend that all parts are done + switchMap(() => { + const finishParams = new HttpParams() + .set("type", "finish") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)); + + return this.http.post( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: finishParams } + ); + }), tap(() => { + const totalTime = getTotalTime(); observer.next({ filePath, percentage: 100, status: "finished", - uploadId: uploadId, - physicalAddress: physicalAddress, uploadSpeed: 0, estimatedTimeRemaining: 0, - totalTime: getTotalTime(), + totalTime, }); observer.complete(); }), catchError((error: unknown) => { - // If an error occurred, abort the upload + // On error, compute best-effort percentage from bytes we've seen + let totalUploaded = 0; + partProgress.forEach(bytes => { + totalUploaded += bytes; + }); + const percentage = file.size > 0 ? Math.round((totalUploaded / file.size) * 100) : 0; + observer.next({ filePath, - percentage: Math.round((uploadedParts.length / partCount) * 100), + percentage, status: "aborted", - uploadId: uploadId, - physicalAddress: physicalAddress, uploadSpeed: 0, estimatedTimeRemaining: 0, totalTime: getTotalTime(), }); - return this.finalizeMultipartUpload( - ownerEmail, - datasetName, - filePath, - uploadId, - uploadedParts, - physicalAddress, - true - ).pipe(switchMap(() => throwError(() => error))); + // Abort on backend + const abortParams = new HttpParams() + .set("type", "abort") + .set("ownerEmail", ownerEmail) + .set("datasetName", datasetName) + .set("filePath", encodeURIComponent(filePath)); + + return this.http + .post( + `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, + {}, + { params: abortParams } + ) + .pipe( + switchMap(() => throwError(() => error)), + catchError(() => throwError(() => error)) + ); }) ); }) @@ -370,59 +423,26 @@ export class DatasetService { .subscribe({ error: (err: unknown) => observer.error(err), }); + return () => subscription.unsubscribe(); }); } - /** - * Initiates a multipart upload and retrieves presigned URLs for each part. - * @param ownerEmail Owner's email - * @param datasetName Dataset Name - * @param filePath File path within the dataset - * @param numParts Number of parts for the multipart upload - */ - private initiateMultipartUpload( - ownerEmail: string, - datasetName: string, - filePath: string, - numParts: number - ): Observable<{ uploadId: string; presignedUrls: string[]; physicalAddress: string }> { - const params = new HttpParams() - .set("type", "init") - .set("ownerEmail", ownerEmail) - .set("datasetName", datasetName) - .set("filePath", encodeURIComponent(filePath)) - .set("numParts", numParts.toString()); - - return this.http.post<{ uploadId: string; presignedUrls: string[]; physicalAddress: string }>( - `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, - {}, - { params } - ); - } - - /** - * Completes or aborts a multipart upload, sending part numbers and ETags to the backend. - */ public finalizeMultipartUpload( ownerEmail: string, datasetName: string, filePath: string, - uploadId: string, - parts: { PartNumber: number; ETag: string }[], - physicalAddress: string, isAbort: boolean ): Observable { const params = new HttpParams() .set("type", isAbort ? 
"abort" : "finish") .set("ownerEmail", ownerEmail) .set("datasetName", datasetName) - .set("filePath", encodeURIComponent(filePath)) - .set("uploadId", uploadId); + .set("filePath", encodeURIComponent(filePath)); return this.http.post( `${AppSettings.getApiEndpoint()}/${DATASET_BASE_URL}/multipart-upload`, - { parts, physicalAddress }, + {}, { params } ); } diff --git a/sql/texera_ddl.sql b/sql/texera_ddl.sql index 48e51dca873..57ac69b6876 100644 --- a/sql/texera_ddl.sql +++ b/sql/texera_ddl.sql @@ -58,6 +58,9 @@ DROP TABLE IF EXISTS workflow_version CASCADE; DROP TABLE IF EXISTS project CASCADE; DROP TABLE IF EXISTS workflow_of_project CASCADE; DROP TABLE IF EXISTS workflow_executions CASCADE; +DROP TABLE IF EXISTS dataset_upload_session CASCADE; +DROP TABLE IF EXISTS dataset_upload_session_part CASCADE; + DROP TABLE IF EXISTS dataset CASCADE; DROP TABLE IF EXISTS dataset_user_access CASCADE; DROP TABLE IF EXISTS dataset_version CASCADE; @@ -275,6 +278,36 @@ CREATE TABLE IF NOT EXISTS dataset_version FOREIGN KEY (did) REFERENCES dataset(did) ON DELETE CASCADE ); +CREATE TABLE IF NOT EXISTS dataset_upload_session +( + did INT NOT NULL, + uid INT NOT NULL, + file_path TEXT NOT NULL, + upload_id VARCHAR(256) NOT NULL UNIQUE, + physical_address TEXT, + num_parts_requested INT NOT NULL, + + PRIMARY KEY (uid, did, file_path), + + FOREIGN KEY (did) REFERENCES dataset(did) ON DELETE CASCADE, + FOREIGN KEY (uid) REFERENCES "user"(uid) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS dataset_upload_session_part +( + upload_id VARCHAR(256) NOT NULL, + part_number INT NOT NULL, + etag TEXT NOT NULL DEFAULT '', + + PRIMARY KEY (upload_id, part_number), + + CONSTRAINT chk_part_number_positive CHECK (part_number > 0), + + FOREIGN KEY (upload_id) + REFERENCES dataset_upload_session(upload_id) + ON DELETE CASCADE +); + -- operator_executions (modified to match MySQL: no separate primary key; added console_messages_uri) CREATE TABLE IF NOT EXISTS operator_executions ( diff --git a/sql/updates/17.sql b/sql/updates/17.sql new file mode 100644 index 00000000000..9436c405286 --- /dev/null +++ b/sql/updates/17.sql @@ -0,0 +1,66 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ============================================ +-- 1. Connect to the texera_db database +-- ============================================ +\c texera_db + +SET search_path TO texera_db; + +-- ============================================ +-- 2. Update the table schema +-- ============================================ +BEGIN; + +-- 1. Drop old tables (if exist) +DROP TABLE IF EXISTS dataset_upload_session CASCADE; +DROP TABLE IF EXISTS dataset_upload_session_part CASCADE; + +-- 2. 
Create dataset upload session table +CREATE TABLE IF NOT EXISTS dataset_upload_session +( + did INT NOT NULL, + uid INT NOT NULL, + file_path TEXT NOT NULL, + upload_id VARCHAR(256) NOT NULL UNIQUE, + physical_address TEXT, + num_parts_requested INT NOT NULL, + + PRIMARY KEY (uid, did, file_path), + + FOREIGN KEY (did) REFERENCES dataset(did) ON DELETE CASCADE, + FOREIGN KEY (uid) REFERENCES "user"(uid) ON DELETE CASCADE + ); + +-- 3. Create dataset upload session parts table +CREATE TABLE IF NOT EXISTS dataset_upload_session_part +( + upload_id VARCHAR(256) NOT NULL, + part_number INT NOT NULL, + etag TEXT NOT NULL DEFAULT '', + + PRIMARY KEY (upload_id, part_number), + + CONSTRAINT chk_part_number_positive CHECK (part_number > 0), + + FOREIGN KEY (upload_id) + REFERENCES dataset_upload_session(upload_id) + ON DELETE CASCADE + ); + +COMMIT;
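
The tests above only let a finish call succeed when every part row carries a non-empty etag and the number of rows equals num_parts_requested; an empty etag is the placeholder written at init time for a part that was never uploaded. A minimal sketch of that invariant as a query against the two tables created by this migration (illustrative only, not part of the patch; the query itself is an assumption, not code shipped by this change):

-- Illustrative sketch: sessions whose uploads are complete, i.e. exactly
-- num_parts_requested part rows, each with a non-empty etag
-- (etag = '' marks a placeholder part that was never uploaded).
SELECT s.upload_id
FROM dataset_upload_session s
JOIN dataset_upload_session_part p
  ON p.upload_id = s.upload_id
GROUP BY s.upload_id, s.num_parts_requested
HAVING COUNT(*) = s.num_parts_requested
   AND COUNT(*) = COUNT(NULLIF(p.etag, ''));

Sessions that fail this shape of check are the ones the resource rejects with 409, as exercised by the "finish-no-parts" and "finish-missing" tests earlier in this patch.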