diff --git a/plugins/repository-azure/build.gradle b/plugins/repository-azure/build.gradle index 7dfaca37a25ab..8e417c5217319 100644 --- a/plugins/repository-azure/build.gradle +++ b/plugins/repository-azure/build.gradle @@ -410,3 +410,18 @@ task azureThirdPartyTest(type: Test) { } } tasks.named("check").configure { dependsOn("azureThirdPartyTest") } + +// test jar is exported by the integTestArtifacts configuration to be used in the encrypted Azure repository test +configurations { + internalClusterTestArtifacts.extendsFrom internalClusterTestImplementation + internalClusterTestArtifacts.extendsFrom internalClusterTestRuntime +} + +def internalClusterTestJar = tasks.register("internalClusterTestJar", Jar) { + appendix 'internalClusterTest' + from sourceSets.internalClusterTest.output +} + +artifacts { + internalClusterTestArtifacts internalClusterTestJar +} diff --git a/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java b/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java index 47d0952eb4c7a..d334929f31f19 100644 --- a/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java +++ b/plugins/repository-azure/src/internalClusterTest/java/org/elasticsearch/repositories/azure/AzureBlobStoreRepositoryTests.java @@ -58,12 +58,15 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { - return Settings.builder() - .put(super.repositorySettings()) - .put(AzureRepository.Repository.CONTAINER_SETTING.getKey(), "container") - .put(AzureStorageSettings.ACCOUNT_SETTING.getKey(), "test") - .build(); + protected Settings repositorySettings(String repoName) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.repositorySettings(repoName)) + .put(AzureRepository.Repository.CONTAINER_SETTING.getKey(), 
"container") + .put(AzureStorageSettings.ACCOUNT_SETTING.getKey(), "test"); + if (randomBoolean()) { + settingsBuilder.put(AzureRepository.Repository.BASE_PATH_SETTING.getKey(), randomFrom("test", "test/1")); + } + return settingsBuilder.build(); } @Override diff --git a/plugins/repository-gcs/build.gradle b/plugins/repository-gcs/build.gradle index 3436f5af7aa32..8cdf7c701a1be 100644 --- a/plugins/repository-gcs/build.gradle +++ b/plugins/repository-gcs/build.gradle @@ -334,3 +334,20 @@ def gcsThirdPartyTest = tasks.register("gcsThirdPartyTest", Test) { tasks.named("check").configure { dependsOn(largeBlobYamlRestTest, gcsThirdPartyTest) } + +// test jar is exported by the integTestArtifacts configuration to be used in the encrypted GCS repository test +configurations { + internalClusterTestArtifacts.extendsFrom internalClusterTestImplementation + internalClusterTestArtifacts.extendsFrom internalClusterTestRuntime +} + +def internalClusterTestJar = tasks.register("internalClusterTestJar", Jar) { + appendix 'internalClusterTest' + from sourceSets.internalClusterTest.output + // for the repositories.gcs.TestUtils class + from sourceSets.test.output +} + +artifacts { + internalClusterTestArtifacts internalClusterTestJar +} diff --git a/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java b/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java index 36733815ca2aa..830da6f822407 100644 --- a/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java +++ b/plugins/repository-gcs/src/internalClusterTest/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStoreRepositoryTests.java @@ -70,6 +70,7 @@ import static org.elasticsearch.repositories.gcs.GoogleCloudStorageClientSettings.CREDENTIALS_FILE_SETTING; import static 
org.elasticsearch.repositories.gcs.GoogleCloudStorageClientSettings.ENDPOINT_SETTING; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageClientSettings.TOKEN_URI_SETTING; +import static org.elasticsearch.repositories.gcs.GoogleCloudStorageRepository.BASE_PATH; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageRepository.BUCKET; import static org.elasticsearch.repositories.gcs.GoogleCloudStorageRepository.CLIENT_NAME; @@ -94,12 +95,15 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { - return Settings.builder() - .put(super.repositorySettings()) - .put(BUCKET.getKey(), "bucket") - .put(CLIENT_NAME.getKey(), "test") - .build(); + protected Settings repositorySettings(String repoName) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.repositorySettings(repoName)) + .put(BUCKET.getKey(), "bucket") + .put(CLIENT_NAME.getKey(), "test"); + if (randomBoolean()) { + settingsBuilder.put(BASE_PATH.getKey(), randomFrom("test", "test/1")); + } + return settingsBuilder.build(); } @Override @@ -135,7 +139,7 @@ protected Settings nodeSettings(int nodeOrdinal) { } public void testDeleteSingleItem() { - final String repoName = createRepository(randomName()); + final String repoName = createRepository(randomRepositoryName()); final RepositoriesService repositoriesService = internalCluster().getMasterNodeInstance(RepositoriesService.class); final BlobStoreRepository repository = (BlobStoreRepository) repositoriesService.repository(repoName); PlainActionFuture.get(f -> repository.threadPool().generic().execute(ActionRunnable.run(f, () -> diff --git a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java index a427cee824209..e87c2ac5c1b29 100644 --- 
a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java +++ b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java @@ -38,7 +38,7 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { + protected Settings repositorySettings(String repoName) { return Settings.builder() .put("uri", "hdfs:///") .put("conf.fs.AbstractFileSystem.hdfs.impl", TestingFs.class.getName()) @@ -47,6 +47,12 @@ protected Settings repositorySettings() { .put("compress", randomBoolean()).build(); } + @Override + public void testSnapshotAndRestore() throws Exception { + // the HDFS mockup doesn't preserve the repository contents after removing the repository + testSnapshotAndRestore(false); + } + @Override protected Collection> nodePlugins() { return Collections.singletonList(HdfsPlugin.class); diff --git a/plugins/repository-s3/build.gradle b/plugins/repository-s3/build.gradle index 02da20c7f1192..be5120170a463 100644 --- a/plugins/repository-s3/build.gradle +++ b/plugins/repository-s3/build.gradle @@ -432,3 +432,20 @@ tasks.named("thirdPartyAudit").configure { ignoreMissingClasses 'javax.activation.DataHandler' } } + +// test jar is exported by the integTestArtifacts configuration to be used in the encrypted S3 repository test +configurations { + internalClusterTestArtifacts.extendsFrom internalClusterTestImplementation + internalClusterTestArtifacts.extendsFrom internalClusterTestRuntime +} + +def internalClusterTestJar = tasks.register("internalClusterTestJar", Jar) { + appendix 'internalClusterTest' + from sourceSets.internalClusterTest.output + // for the plugin-security.policy resource + from sourceSets.test.output +} + +artifacts { + internalClusterTestArtifacts internalClusterTestJar +} diff --git a/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java 
b/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java index 6dcef8ddf554b..4a8f308ce1bba 100644 --- a/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/plugins/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -94,14 +94,17 @@ protected String repositoryType() { } @Override - protected Settings repositorySettings() { - return Settings.builder() - .put(super.repositorySettings()) + protected Settings repositorySettings(String repoName) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.repositorySettings(repoName)) .put(S3Repository.BUCKET_SETTING.getKey(), "bucket") .put(S3Repository.CLIENT_NAME.getKey(), "test") // Don't cache repository data because some tests manually modify the repository data - .put(BlobStoreRepository.CACHE_REPOSITORY_DATA.getKey(), false) - .build(); + .put(BlobStoreRepository.CACHE_REPOSITORY_DATA.getKey(), false); + if (randomBoolean()) { + settingsBuilder.put(S3Repository.BASE_PATH_SETTING.getKey(), randomFrom("test", "test/1")); + } + return settingsBuilder.build(); } @Override @@ -145,8 +148,9 @@ protected Settings nodeSettings(int nodeOrdinal) { } public void testEnforcedCooldownPeriod() throws IOException { - final String repoName = createRepository(randomName(), Settings.builder().put(repositorySettings()) - .put(S3Repository.COOLDOWN_PERIOD.getKey(), TEST_COOLDOWN_PERIOD).build()); + final String repoName = randomRepositoryName(); + createRepository(repoName, Settings.builder().put(repositorySettings(repoName)) + .put(S3Repository.COOLDOWN_PERIOD.getKey(), TEST_COOLDOWN_PERIOD).build(), true); final SnapshotId fakeOldSnapshot = client().admin().cluster().prepareCreateSnapshot(repoName, "snapshot-old") .setWaitForCompletion(true).setIndices().get().getSnapshotInfo().snapshotId(); diff --git 
a/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java b/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..268f47130601d --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIntegTests.java @@ -0,0 +1,40 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.repositories.fs; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.repositories.blobstore.ESFsBasedRepositoryIntegTestCase; + +public class FsBlobStoreRepositoryIntegTests extends ESFsBasedRepositoryIntegTestCase { + + @Override + protected Settings repositorySettings(String repositoryName) { + final Settings.Builder settings = Settings.builder() + .put("compress", randomBoolean()) + .put("location", randomRepoPath()); + if (randomBoolean()) { + long size = 1 << randomInt(10); + settings.put("chunk_size", new ByteSizeValue(size, ByteSizeUnit.KB)); + } + return settings.build(); + } +} diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java b/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java index b918bea593f22..4004347ee62dd 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/BlobPath.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Objects; /** * The list of paths where a blob can reside. The contents of the paths are dependent upon the implementation of {@link BlobContainer}. 
@@ -91,4 +92,17 @@ public String toString() { } return sb.toString(); } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BlobPath other = (BlobPath) o; + return paths.equals(other.paths); + } + + @Override + public int hashCode() { + return Objects.hash(paths); + } } diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java index da0316be5c419..9fae32ff61f3c 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java @@ -41,6 +41,7 @@ import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryMissingException; import org.elasticsearch.snapshots.SnapshotMissingException; import org.elasticsearch.snapshots.SnapshotRestoreException; import org.elasticsearch.test.ESIntegTestCase; @@ -78,17 +79,19 @@ public static RepositoryData getRepositoryData(Repository repository) { protected abstract String repositoryType(); - protected Settings repositorySettings() { + protected Settings repositorySettings(String repoName) { return Settings.builder().put("compress", randomBoolean()).build(); } protected final String createRepository(final String name) { - return createRepository(name, repositorySettings()); + return createRepository(name, true); } - protected final String createRepository(final String name, final Settings settings) { - final boolean verify = randomBoolean(); + protected final String createRepository(final String name, final boolean verify) { + return 
createRepository(name, repositorySettings(name), verify); + } + protected final String createRepository(final String name, final Settings settings, final boolean verify) { logger.info("--> creating repository [name: {}, verify: {}, settings: {}]", name, verify, settings); assertAcked(client().admin().cluster().preparePutRepository(name) .setType(repositoryType()) @@ -98,7 +101,7 @@ protected final String createRepository(final String name, final Settings settin internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { assertThat(repositories.repository(name), notNullValue()); assertThat(repositories.repository(name), instanceOf(BlobStoreRepository.class)); - assertThat(repositories.repository(name).isReadOnly(), is(false)); + assertThat(repositories.repository(name).isReadOnly(), is(settings.getAsBoolean("readonly", false))); BlobStore blobStore = ((BlobStoreRepository) repositories.repository(name)).getBlobStore(); assertThat("blob store has to be lazy initialized", blobStore, verify ? 
is(notNullValue()) : is(nullValue())); }); @@ -106,6 +109,15 @@ protected final String createRepository(final String name, final Settings settin return name; } + protected final void deleteRepository(final String name) { + logger.debug("--> deleting repository [name: {}]", name); + assertAcked(client().admin().cluster().prepareDeleteRepository(name)); + internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { + RepositoryMissingException e = expectThrows(RepositoryMissingException.class, () -> repositories.repository(name)); + assertThat(e.repository(), equalTo(name)); + }); + } + public void testReadNonExistingPath() throws IOException { try (BlobStore store = newBlobStore()) { final BlobContainer container = store.blobContainer(new BlobPath()); @@ -176,7 +188,7 @@ public void testList() throws IOException { BlobMetadata blobMetadata = blobs.get(generated.getKey()); assertThat(generated.getKey(), blobMetadata, CoreMatchers.notNullValue()); assertThat(blobMetadata.name(), CoreMatchers.equalTo(generated.getKey())); - assertThat(blobMetadata.length(), CoreMatchers.equalTo(generated.getValue())); + assertThat(blobMetadata.length(), CoreMatchers.equalTo(blobLengthFromContentLength(generated.getValue()))); } assertThat(container.listBlobsByPrefix("foo-").size(), CoreMatchers.equalTo(numberOfFooBlobs)); @@ -259,7 +271,11 @@ protected static void writeBlob(BlobContainer container, String blobName, BytesA } protected BlobStore newBlobStore() { - final String repository = createRepository(randomName()); + final String repository = createRepository(randomRepositoryName()); + return newBlobStore(repository); + } + + protected BlobStore newBlobStore(String repository) { final BlobStoreRepository blobStoreRepository = (BlobStoreRepository) internalCluster().getMasterNodeInstance(RepositoriesService.class).repository(repository); return PlainActionFuture.get( @@ -267,7 +283,13 @@ protected BlobStore newBlobStore() { } public void 
testSnapshotAndRestore() throws Exception { - final String repoName = createRepository(randomName()); + testSnapshotAndRestore(randomBoolean()); + } + + protected void testSnapshotAndRestore(boolean recreateRepositoryBeforeRestore) throws Exception { + final String repoName = randomRepositoryName(); + final Settings repoSettings = repositorySettings(repoName); + createRepository(repoName, repoSettings, randomBoolean()); int indexCount = randomIntBetween(1, 5); int[] docCounts = new int[indexCount]; String[] indexNames = generateRandomNames(indexCount); @@ -315,6 +337,11 @@ public void testSnapshotAndRestore() throws Exception { assertAcked(client().admin().indices().prepareClose(closeIndices.toArray(new String[closeIndices.size()]))); } + if (recreateRepositoryBeforeRestore) { + deleteRepository(repoName); + createRepository(repoName, repoSettings, randomBoolean()); + } + logger.info("--> restore all indices from the snapshot"); assertSuccessfulRestore(client().admin().cluster().prepareRestoreSnapshot(repoName, snapshotName).setWaitForCompletion(true)); @@ -339,7 +366,7 @@ public void testSnapshotAndRestore() throws Exception { } public void testMultipleSnapshotAndRollback() throws Exception { - final String repoName = createRepository(randomName()); + final String repoName = createRepository(randomRepositoryName()); int iterationCount = randomIntBetween(2, 5); int[] docCounts = new int[iterationCount]; String indexName = randomName(); @@ -394,7 +421,7 @@ public void testMultipleSnapshotAndRollback() throws Exception { } public void testIndicesDeletedFromRepository() throws Exception { - final String repoName = createRepository("test-repo"); + final String repoName = createRepository(randomRepositoryName()); Client client = client(); createIndex("test-idx-1", "test-idx-2", "test-idx-3"); ensureGreen(); @@ -491,7 +518,15 @@ private static void assertSuccessfulRestore(RestoreSnapshotResponse response) { assertThat(response.getRestoreInfo().successfulShards(), 
equalTo(response.getRestoreInfo().totalShards())); } - protected static String randomName() { + protected String randomName() { return randomAlphaOfLength(randomIntBetween(1, 10)).toLowerCase(Locale.ROOT); } + + protected String randomRepositoryName() { + return randomName(); + } + + protected long blobLengthFromContentLength(long contentLength) { + return contentLength; + } } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java similarity index 69% rename from server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java rename to test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java index 9ad02412f3771..3a8501d65e95a 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/repositories/fs/FsBlobStoreRepositoryIT.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESFsBasedRepositoryIntegTestCase.java @@ -7,7 +7,7 @@ * not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -16,18 +16,16 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.elasticsearch.repositories.fs; +package org.elasticsearch.repositories.blobstore; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; -import org.elasticsearch.common.blobstore.fs.FsBlobStore; +import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.unit.ByteSizeUnit; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.internal.io.IOUtils; -import org.elasticsearch.repositories.blobstore.ESBlobStoreRepositoryIntegTestCase; +import org.elasticsearch.repositories.fs.FsRepository; import java.io.IOException; import java.nio.file.Files; @@ -39,35 +37,22 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.hamcrest.Matchers.instanceOf; -public class FsBlobStoreRepositoryIT extends ESBlobStoreRepositoryIntegTestCase { +public abstract class ESFsBasedRepositoryIntegTestCase extends ESBlobStoreRepositoryIntegTestCase { @Override protected String repositoryType() { return FsRepository.TYPE; } - @Override - protected Settings repositorySettings() { - final Settings.Builder settings = Settings.builder(); - settings.put(super.repositorySettings()); - settings.put("location", randomRepoPath()); - if (randomBoolean()) { - long size = 1 << randomInt(10); - settings.put("chunk_size", new ByteSizeValue(size, ByteSizeUnit.KB)); - } - return settings.build(); - } - public void testMissingDirectoriesNotCreatedInReadonlyRepository() throws IOException, InterruptedException { - final String repoName = randomName(); + final String repoName = randomRepositoryName(); final Path repoPath = randomRepoPath(); - logger.info("--> creating repository {} at {}", repoName, repoPath); - - 
assertAcked(client().admin().cluster().preparePutRepository(repoName).setType("fs").setSettings(Settings.builder() - .put("location", repoPath) - .put("compress", randomBoolean()) - .put("chunk_size", randomIntBetween(100, 1000), ByteSizeUnit.BYTES))); + final Settings repoSettings = Settings.builder() + .put(repositorySettings(repoName)) + .put("location", repoPath) + .build(); + createRepository(repoName, repoSettings, randomBoolean()); final String indexName = randomName(); int docCount = iterations(10, 1000); @@ -91,8 +76,7 @@ public void testMissingDirectoriesNotCreatedInReadonlyRepository() throws IOExce } assertFalse(Files.exists(deletedPath)); - assertAcked(client().admin().cluster().preparePutRepository(repoName).setType("fs").setSettings(Settings.builder() - .put("location", repoPath).put("readonly", true))); + createRepository(repoName, Settings.builder().put(repoSettings).put("readonly", true).build(), randomBoolean()); final ElasticsearchException exception = expectThrows(ElasticsearchException.class, () -> client().admin().cluster().prepareRestoreSnapshot(repoName, snapshotName).setWaitForCompletion(randomBoolean()).get()); @@ -102,25 +86,34 @@ public void testMissingDirectoriesNotCreatedInReadonlyRepository() throws IOExce } public void testReadOnly() throws Exception { - Path tempDir = createTempDir(); - Path path = tempDir.resolve("bar"); - - try (FsBlobStore store = new FsBlobStore(randomIntBetween(1, 8) * 1024, path, true)) { - assertFalse(Files.exists(path)); + final String repoName = randomRepositoryName(); + final Path repoPath = randomRepoPath(); + final Settings repoSettings = Settings.builder() + .put(repositorySettings(repoName)) + .put("readonly", true) + .put(FsRepository.LOCATION_SETTING.getKey(), repoPath) + .put(BlobStoreRepository.BUFFER_SIZE_SETTING.getKey(), String.valueOf(randomIntBetween(1, 8) * 1024) + "kb") + .build(); + createRepository(repoName, repoSettings, false); + + try (BlobStore store = newBlobStore(repoName)) { + 
assertFalse(Files.exists(repoPath)); BlobPath blobPath = BlobPath.cleanPath().add("foo"); store.blobContainer(blobPath); - Path storePath = store.path(); + Path storePath = repoPath; for (String d : blobPath) { storePath = storePath.resolve(d); } assertFalse(Files.exists(storePath)); } - try (FsBlobStore store = new FsBlobStore(randomIntBetween(1, 8) * 1024, path, false)) { - assertTrue(Files.exists(path)); + createRepository(repoName, Settings.builder().put(repoSettings).put("readonly", false).build(), false); + + try (BlobStore store = newBlobStore(repoName)) { + assertTrue(Files.exists(repoPath)); BlobPath blobPath = BlobPath.cleanPath().add("foo"); BlobContainer container = store.blobContainer(blobPath); - Path storePath = store.path(); + Path storePath = repoPath; for (String d : blobPath) { storePath = storePath.resolve(d); } diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java index 73f5cb889ddd4..57349764a8cc6 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESMockAPIBasedRepositoryIntegTestCase.java @@ -34,12 +34,15 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.mocksocket.MockHttpServer; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; import org.elasticsearch.repositories.RepositoryMissingException; import org.elasticsearch.repositories.RepositoryStats; import 
org.elasticsearch.test.BackgroundIndexer; +import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -55,6 +58,9 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -83,6 +89,7 @@ protected interface BlobStoreHttpHandler extends HttpHandler { private static final byte[] BUFFER = new byte[1024]; private static HttpServer httpServer; + private static ExecutorService executorService; protected Map handlers; private static final Logger log = LogManager.getLogger(); @@ -90,13 +97,19 @@ protected interface BlobStoreHttpHandler extends HttpHandler { @BeforeClass public static void startHttpServer() throws Exception { httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0); + ThreadFactory threadFactory = EsExecutors.daemonThreadFactory("[" + ESMockAPIBasedRepositoryIntegTestCase.class.getName() + "]"); + // the EncryptedRepository can require more than one connection open at one time + executorService = EsExecutors.newScaling(ESMockAPIBasedRepositoryIntegTestCase.class.getName(), 0, 2, 60, + TimeUnit.SECONDS, threadFactory, new ThreadContext(Settings.EMPTY)); httpServer.setExecutor(r -> { - try { - r.run(); - } catch (Throwable t) { - log.error("Error in execution on mock http server IO thread", t); - throw t; - } + executorService.execute(() -> { + try { + r.run(); + } catch (Throwable t) { + log.error("Error in execution on mock http server IO thread", t); + throw t; + } + }); }); httpServer.start(); } @@ -111,6 +124,7 @@ public void setUpHttpServer() { @AfterClass public static void stopHttpServer() { httpServer.stop(0); + 
ThreadPool.terminate(executorService, 10, TimeUnit.SECONDS); httpServer = null; } @@ -124,14 +138,17 @@ public void tearDownHttpServer() { h = ((DelegatingHttpHandler) h).getDelegate(); } if (h instanceof BlobStoreHttpHandler) { - List blobs = ((BlobStoreHttpHandler) h).blobs().keySet().stream() - .filter(blob -> blob.contains("index") == false).collect(Collectors.toList()); - assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + assertEmptyRepo(((BlobStoreHttpHandler) h).blobs()); } } } } + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet().stream().filter(blob -> blob.contains("index") == false).collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + protected abstract Map createHttpHandlers(); protected abstract HttpHandler createErroneousHttpHandler(HttpHandler delegate); @@ -139,8 +156,8 @@ public void tearDownHttpServer() { /** * Test the snapshot and restore of an index which has large segments files. 
*/ - public void testSnapshotWithLargeSegmentFiles() throws Exception { - final String repository = createRepository(randomName()); + public final void testSnapshotWithLargeSegmentFiles() throws Exception { + final String repository = createRepository(randomRepositoryName()); final String index = "index-no-merges"; createIndex(index, Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) @@ -171,7 +188,7 @@ public void testSnapshotWithLargeSegmentFiles() throws Exception { } public void testRequestStats() throws Exception { - final String repository = createRepository(randomName()); + final String repository = createRepository(randomRepositoryName()); final String index = "index-no-merges"; createIndex(index, Settings.builder() .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java index 738a16ff134c4..ed5723426d804 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/XPackLicenseState.java @@ -62,6 +62,8 @@ public enum Feature { MONITORING_CLUSTER_ALERTS(OperationMode.STANDARD, true), MONITORING_UPDATE_RETENTION(OperationMode.STANDARD, false), + ENCRYPTED_SNAPSHOT(OperationMode.PLATINUM, true), + CCR(OperationMode.PLATINUM, true), GRAPH(OperationMode.PLATINUM, true), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java index e8e1bb2ef067a..f680d42810d0f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/XPackLicenseStateTests.java @@ -20,6 +20,7 @@ import java.util.stream.Collectors; import static 
org.elasticsearch.license.License.OperationMode.BASIC; +import static org.elasticsearch.license.License.OperationMode.ENTERPRISE; import static org.elasticsearch.license.License.OperationMode.GOLD; import static org.elasticsearch.license.License.OperationMode.MISSING; import static org.elasticsearch.license.License.OperationMode.PLATINUM; @@ -299,6 +300,24 @@ public void testWatcherInactivePlatinumGoldTrial() throws Exception { assertAllowed(STANDARD, false, s -> s.checkFeature(Feature.WATCHER), false); } + public void testEncryptedSnapshotsWithInactiveLicense() { + assertAllowed(BASIC, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(TRIAL, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(GOLD, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(PLATINUM, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(ENTERPRISE, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(STANDARD, false, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + } + + public void testEncryptedSnapshotsWithActiveLicense() { + assertAllowed(BASIC, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(TRIAL, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), true); + assertAllowed(GOLD, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + assertAllowed(PLATINUM, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), true); + assertAllowed(ENTERPRISE, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), true); + assertAllowed(STANDARD, true, s -> s.checkFeature(Feature.ENCRYPTED_SNAPSHOT), false); + } + public void testGraphPlatinumTrial() throws Exception { assertAllowed(TRIAL, true, s -> s.checkFeature(Feature.GRAPH), true); assertAllowed(PLATINUM, true, s -> s.checkFeature(Feature.GRAPH), true); diff --git a/x-pack/plugin/repository-encrypted/build.gradle 
b/x-pack/plugin/repository-encrypted/build.gradle new file mode 100644 index 0000000000000..1a0eccc6421e6 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/build.gradle @@ -0,0 +1,30 @@ +evaluationDependsOn(xpackModule('core')) + +apply plugin: 'elasticsearch.esplugin' +apply plugin: 'elasticsearch.internal-cluster-test' +esplugin { + name 'repository-encrypted' + description 'Elasticsearch Expanded Pack Plugin - client-side encrypted repositories.' + classname 'org.elasticsearch.repositories.encrypted.EncryptedRepositoryPlugin' + extendedPlugins = ['x-pack-core'] +} +archivesBaseName = 'x-pack-repository-encrypted' + +dependencies { + // necessary for the license check + compileOnly project(path: xpackModule('core'), configuration: 'default') + testImplementation project(path: xpackModule('core'), configuration: 'testArtifacts') + // required for integ tests of encrypted FS repository + internalClusterTestImplementation project(":test:framework") + // required for integ tests of encrypted cloud repositories + internalClusterTestImplementation project(path: ':plugins:repository-gcs', configuration: 'internalClusterTestArtifacts') + internalClusterTestImplementation project(path: ':plugins:repository-azure', configuration: 'internalClusterTestArtifacts') + internalClusterTestImplementation(project(path: ':plugins:repository-s3', configuration: 'internalClusterTestArtifacts')) { + // HACK, resolves jar hell, such as: + // jar1: jakarta.xml.bind/jakarta.xml.bind-api/2.3.2/8d49996a4338670764d7ca4b85a1c4ccf7fe665d/jakarta.xml.bind-api-2.3.2.jar + // jar2: javax.xml.bind/jaxb-api/2.2.2/aeb3021ca93dde265796d82015beecdcff95bf09/jaxb-api-2.2.2.jar + exclude group: 'javax.xml.bind', module: 'jaxb-api' + } + // for encrypted GCS repository integ tests + internalClusterTestRuntimeOnly 'com.google.guava:guava:26.0-jre' +} diff --git 
a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..0c9afbb7f535a --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedAzureBlobStoreRepositoryIntegTests.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.azure.AzureBlobStoreRepositoryTests; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.DEK_ROOT_CONTAINER; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.hamcrest.Matchers.hasSize; + +public final class EncryptedAzureBlobStoreRepositoryIntegTests extends AzureBlobStoreRepositoryTests { + private static List repositoryNames; + + @BeforeClass + private static void preGenerateRepositoryNames() { + List names = new ArrayList<>(); + 
for (int i = 0; i < 32; i++) { + names.add("test-repo-" + i); + } + repositoryNames = Collections.synchronizedList(names); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()); + MockSecureSettings superSecureSettings = (MockSecureSettings) settingsBuilder.getSecureSettings(); + superSecureSettings.merge(nodeSecureSettings()); + return settingsBuilder.build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + String.format(Locale.ROOT, "%14s", repositoryName) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class, TestAzureRepositoryPlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + @Override + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put(super.repositorySettings(repositoryName)) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), "azure") + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .build(); + } + + @Override + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet() + .stream() + .filter(blob -> false == blob.contains("index")) + .filter(blob -> false == blob.contains(DEK_ROOT_CONTAINER)) 
// encryption metadata "leaks" + .collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..8df7aa5e617b7 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedFSBlobStoreRepositoryIntegTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.RepositoriesService; +import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.repositories.blobstore.ESFsBasedRepositoryIntegTestCase; +import org.elasticsearch.repositories.fs.FsRepository; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Locale; +import java.util.stream.Stream; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.containsString; + +public final class EncryptedFSBlobStoreRepositoryIntegTests extends ESFsBasedRepositoryIntegTestCase { + private static int NUMBER_OF_TEST_REPOSITORIES = 32; + + private static List repositoryNames = new ArrayList<>(); + + @BeforeClass + private static void preGenerateRepositoryNames() { + for (int i = 0; i < NUMBER_OF_TEST_REPOSITORIES; i++) { + repositoryNames.add("test-repo-" + i); + } + } + + @Override + protected Settings repositorySettings(String repositoryName) { + final Settings.Builder settings = 
Settings.builder() + .put("compress", randomBoolean()) + .put("location", randomRepoPath()) + .put("delegate_type", FsRepository.TYPE) + .put("password_name", repositoryName); + if (randomBoolean()) { + long size = 1 << randomInt(10); + settings.put("chunk_size", new ByteSizeValue(size, ByteSizeUnit.KB)); + } + return settings.build(); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()) + .setSecureSettings(nodeSecureSettings()) + .build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + String.format(Locale.ROOT, "%14s", repositoryName) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + public void testTamperedEncryptionMetadata() throws Exception { + final String repoName = randomRepositoryName(); + final Path repoPath = randomRepoPath(); + final Settings repoSettings = Settings.builder().put(repositorySettings(repoName)).put("location", repoPath).build(); + createRepository(repoName, repoSettings, true); + + final String snapshotName = randomName(); + 
logger.info("--> create snapshot {}:{}", repoName, snapshotName); + client().admin().cluster().prepareCreateSnapshot(repoName, snapshotName).setWaitForCompletion(true).setIndices("other*").get(); + + assertAcked(client().admin().cluster().prepareDeleteRepository(repoName)); + createRepository(repoName, Settings.builder().put(repoSettings).put("readonly", randomBoolean()).build(), randomBoolean()); + + try (Stream rootContents = Files.list(repoPath.resolve(EncryptedRepository.DEK_ROOT_CONTAINER))) { + // tamper all DEKs + rootContents.filter(Files::isDirectory).forEach(DEKRootPath -> { + try (Stream contents = Files.list(DEKRootPath)) { + contents.filter(Files::isRegularFile).forEach(DEKPath -> { + try { + byte[] originalDEKBytes = Files.readAllBytes(DEKPath); + // tamper DEK + int tamperPos = randomIntBetween(0, originalDEKBytes.length - 1); + originalDEKBytes[tamperPos] ^= 0xFF; + Files.write(DEKPath, originalDEKBytes); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + final BlobStoreRepository blobStoreRepository = (BlobStoreRepository) internalCluster().getCurrentMasterNodeInstance( + RepositoriesService.class + ).repository(repoName); + RepositoryException e = expectThrows( + RepositoryException.class, + () -> PlainActionFuture.get( + f -> blobStoreRepository.threadPool().generic().execute(ActionRunnable.wrap(f, blobStoreRepository::getRepositoryData)) + ) + ); + assertThat(e.getMessage(), containsString("the encryption metadata in the repository has been corrupted")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareRestoreSnapshot(repoName, snapshotName).setWaitForCompletion(true).get() + ); + assertThat(e.getMessage(), containsString("the encryption metadata in the repository has been corrupted")); + } + } + +} diff --git 
a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..99e9ec8d7f6c4 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedGCSBlobStoreRepositoryIntegTests.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.gcs.GoogleCloudStorageBlobStoreRepositoryTests; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.DEK_ROOT_CONTAINER; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.hamcrest.Matchers.hasSize; + +public final class EncryptedGCSBlobStoreRepositoryIntegTests extends GoogleCloudStorageBlobStoreRepositoryTests { + + private static List repositoryNames; + + @BeforeClass + private static void preGenerateRepositoryNames() { + List names = new 
ArrayList<>(); + for (int i = 0; i < 32; i++) { + names.add("test-repo-" + i); + } + repositoryNames = Collections.synchronizedList(names); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()); + MockSecureSettings superSecureSettings = (MockSecureSettings) settingsBuilder.getSecureSettings(); + superSecureSettings.merge(nodeSecureSettings()); + return settingsBuilder.build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + String.format(Locale.ROOT, "%14s", repositoryName) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class, TestGoogleCloudStoragePlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + @Override + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put(super.repositorySettings(repositoryName)) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), "gcs") + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .build(); + } + + @Override + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet() + .stream() + .filter(blob -> false == blob.contains("index")) + .filter(blob -> false == 
blob.contains(DEK_ROOT_CONTAINER)) // encryption metadata "leaks" + .collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java new file mode 100644 index 0000000000000..b2b501155a7f8 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedRepositorySecretIntegTests.java @@ -0,0 +1,800 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.admin.cluster.repositories.verify.VerifyRepositoryResponse; +import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse; +import org.elasticsearch.action.admin.cluster.snapshots.get.GetSnapshotsResponse; +import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotResponse; +import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.SnapshotsInProgress; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.RepositoriesService; +import org.elasticsearch.repositories.RepositoryData; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.RepositoryMissingException; +import org.elasticsearch.repositories.RepositoryVerificationException; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.repositories.fs.FsRepository; +import org.elasticsearch.snapshots.Snapshot; +import org.elasticsearch.snapshots.SnapshotInfo; +import org.elasticsearch.snapshots.SnapshotMissingException; +import org.elasticsearch.snapshots.SnapshotState; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.InternalTestCluster; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import 
java.util.stream.Collectors; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, autoManageMasterNodes = false) +public final class EncryptedRepositorySecretIntegTests extends ESIntegTestCase { + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()) + .build(); + } + + public void testRepositoryCreationFailsForMissingPassword() throws Exception { + // if the password is missing on the master node, the repository creation fails + final String repositoryName = randomName(); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + 
logger.info("--> start 3 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNodeName = internalCluster().startNode(); + logger.info("--> started master node " + masterNodeName); + ensureStableCluster(1); + internalCluster().startNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + logger.info("--> started two other nodes"); + ensureStableCluster(3); + assertThat(masterNodeName, equalTo(internalCluster().getMasterName())); + + final Settings repositorySettings = repositorySettings(repositoryName); + RepositoryException e = expectThrows( + RepositoryException.class, + () -> client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(randomBoolean()) + .setSettings(repositorySettings) + .get() + ); + assertThat(e.getMessage(), containsString("failed to create repository")); + expectThrows(RepositoryMissingException.class, () -> client().admin().cluster().prepareGetRepositories(repositoryName).get()); + + if (randomBoolean()) { + // stop the node with the missing password + internalCluster().stopRandomNode(InternalTestCluster.nameFilter(masterNodeName)); + ensureStableCluster(2); + } else { + // restart the node with the missing password + internalCluster().restartNode(masterNodeName, new InternalTestCluster.RestartCallback() { + @Override + public Settings onNodeStopped(String nodeName) throws Exception { + Settings.Builder newSettings = Settings.builder().put(super.onNodeStopped(nodeName)); + newSettings.setSecureSettings(secureSettingsWithPassword); + return newSettings.build(); + } + }); + ensureStableCluster(3); + } + // repository creation now successful + createRepository(repositoryName, repositorySettings, true); + } + + public void testRepositoryVerificationFailsForMissingPassword() throws Exception { + // if the password is missing on any non-master node, the repository verification fails + final String repositoryName = randomName(); + 
MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNodeName = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + logger.info("--> started master node " + masterNodeName); + ensureStableCluster(1); + final String otherNodeName = internalCluster().startNode(); + logger.info("--> started other node " + otherNodeName); + ensureStableCluster(2); + assertThat(masterNodeName, equalTo(internalCluster().getMasterName())); + // repository create fails verification + final Settings repositorySettings = repositorySettings(repositoryName); + expectThrows( + RepositoryVerificationException.class, + () -> client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(true) + .setSettings(repositorySettings) + .get() + ); + if (randomBoolean()) { + // delete and recreate repo + logger.debug("--> deleting repository [name: {}]", repositoryName); + assertAcked(client().admin().cluster().prepareDeleteRepository(repositoryName)); + assertAcked( + client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(false) + .setSettings(repositorySettings) + .get() + ); + } + // test verify call fails + expectThrows(RepositoryVerificationException.class, () -> client().admin().cluster().prepareVerifyRepository(repositoryName).get()); + if (randomBoolean()) { + // stop the node with the missing password + internalCluster().stopRandomNode(InternalTestCluster.nameFilter(otherNodeName)); + ensureStableCluster(1); + // repository verification now succeeds + VerifyRepositoryResponse verifyRepositoryResponse = 
client().admin().cluster().prepareVerifyRepository(repositoryName).get(); + List verifiedNodes = verifyRepositoryResponse.getNodes().stream().map(n -> n.getName()).collect(Collectors.toList()); + assertThat(verifiedNodes, contains(masterNodeName)); + } else { + // restart the node with the missing password + internalCluster().restartNode(otherNodeName, new InternalTestCluster.RestartCallback() { + @Override + public Settings onNodeStopped(String nodeName) throws Exception { + Settings.Builder newSettings = Settings.builder().put(super.onNodeStopped(nodeName)); + newSettings.setSecureSettings(secureSettingsWithPassword); + return newSettings.build(); + } + }); + ensureStableCluster(2); + // repository verification now succeeds + VerifyRepositoryResponse verifyRepositoryResponse = client().admin().cluster().prepareVerifyRepository(repositoryName).get(); + List verifiedNodes = verifyRepositoryResponse.getNodes().stream().map(n -> n.getName()).collect(Collectors.toList()); + assertThat(verifiedNodes, containsInAnyOrder(masterNodeName, otherNodeName)); + } + } + + public void testRepositoryVerificationFailsForDifferentPassword() throws Exception { + final String repositoryName = randomName(); + final String repoPass1 = randomAlphaOfLength(20); + final String repoPass2 = randomAlphaOfLength(19); + // put a different repository password + MockSecureSettings secureSettings1 = new MockSecureSettings(); + secureSettings1.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass1 + ); + MockSecureSettings secureSettings2 = new MockSecureSettings(); + secureSettings2.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass2 + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(1); + final String node1 = 
internalCluster().startNode(Settings.builder().setSecureSettings(secureSettings1).build()); + final String node2 = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettings2).build()); + ensureStableCluster(2); + // repository create fails verification + Settings repositorySettings = repositorySettings(repositoryName); + expectThrows( + RepositoryVerificationException.class, + () -> client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(true) + .setSettings(repositorySettings) + .get() + ); + if (randomBoolean()) { + // delete and recreate repo + logger.debug("--> deleting repository [name: {}]", repositoryName); + assertAcked(client().admin().cluster().prepareDeleteRepository(repositoryName)); + assertAcked( + client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(false) + .setSettings(repositorySettings) + .get() + ); + } + // test verify call fails + expectThrows(RepositoryVerificationException.class, () -> client().admin().cluster().prepareVerifyRepository(repositoryName).get()); + // restart one of the nodes to use the same password + if (randomBoolean()) { + secureSettings1.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass2 + ); + internalCluster().restartNode(node1, new InternalTestCluster.RestartCallback()); + } else { + secureSettings2.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + repoPass1 + ); + internalCluster().restartNode(node2, new InternalTestCluster.RestartCallback()); + } + ensureStableCluster(2); + // repository verification now succeeds + VerifyRepositoryResponse verifyRepositoryResponse = client().admin().cluster().prepareVerifyRepository(repositoryName).get(); + List verifiedNodes = verifyRepositoryResponse.getNodes().stream().map(n -> 
n.getName()).collect(Collectors.toList()); + assertThat(verifiedNodes, containsInAnyOrder(node1, node2)); + } + + public void testLicenseComplianceSnapshotAndRestore() throws Exception { + final String repositoryName = randomName(); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(1); + internalCluster().startNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(2); + + logger.info("--> creating repo " + repositoryName); + createRepository(repositoryName); + final String indexName = randomName(); + logger.info("--> create random index {} with {} records", indexName, 3); + indexRandom( + true, + client().prepareIndex(indexName, "_doc", "1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName, "_doc", "2").setSource("field1", "quick brown"), + client().prepareIndex(indexName, "_doc", "3").setSource("field1", "quick") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 3); + + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repositoryName, snapshotName); + assertSuccessfulSnapshot( + client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName) + .setIndices(indexName) + .setWaitForCompletion(true) + .get() + ); + + // make license not accept encrypted snapshots + EncryptedRepository encryptedRepository = (EncryptedRepository) internalCluster().getCurrentMasterNodeInstance( + RepositoriesService.class + ).repository(repositoryName); + encryptedRepository.licenseStateSupplier = () -> { + XPackLicenseState mockLicenseState = mock(XPackLicenseState.class); + 
when(mockLicenseState.isAllowed(anyObject())).thenReturn(false); + return mockLicenseState; + }; + + // now snapshot is not permitted + ElasticsearchSecurityException e = expectThrows( + ElasticsearchSecurityException.class, + () -> client().admin().cluster().prepareCreateSnapshot(repositoryName, snapshotName + "2").setWaitForCompletion(true).get() + ); + assertThat(e.getDetailedMessage(), containsString("current license is non-compliant for [encrypted snapshots]")); + + logger.info("--> delete index {}", indexName); + assertAcked(client().admin().indices().prepareDelete(indexName)); + + // but restore is permitted + logger.info("--> restore index from the snapshot"); + assertSuccessfulRestore( + client().admin().cluster().prepareRestoreSnapshot(repositoryName, snapshotName).setWaitForCompletion(true).get() + ); + ensureGreen(); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 3); + // also delete snapshot is permitted + logger.info("--> delete snapshot {}:{}", repositoryName, snapshotName); + assertAcked(client().admin().cluster().prepareDeleteSnapshot(repositoryName, snapshotName).get()); + } + + public void testSnapshotIsPartialForMissingPassword() throws Exception { + final String repositoryName = randomName(); + final Settings repositorySettings = repositorySettings(repositoryName); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + randomAlphaOfLength(20) + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + // master has the password + internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(1); + final String otherNode = internalCluster().startNode(); + ensureStableCluster(2); + logger.debug("--> creating repository [name: {}, verify: {}, settings: {}]", 
repositoryName, false, repositorySettings); + assertAcked( + client().admin() + .cluster() + .preparePutRepository(repositoryName) + .setType(repositoryType()) + .setVerify(false) + .setSettings(repositorySettings) + ); + // create an index with the shard on the node without a repository password + final String indexName = randomName(); + final Settings indexSettings = Settings.builder() + .put(indexSettings()) + .put("index.routing.allocation.include._name", otherNode) + .put(SETTING_NUMBER_OF_SHARDS, 1) + .build(); + logger.info("--> create random index {}", indexName); + createIndex(indexName, indexSettings); + indexRandom( + true, + client().prepareIndex(indexName, "_doc", "1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName, "_doc", "2").setSource("field1", "quick brown"), + client().prepareIndex(indexName, "_doc", "3").setSource("field1", "quick") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 3); + + // empty snapshot completes successfully because it does not involve data on the node without a repository password + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repositoryName, snapshotName); + CreateSnapshotResponse createSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName) + .setIndices(indexName + "other*") + .setWaitForCompletion(true) + .get(); + assertThat( + createSnapshotResponse.getSnapshotInfo().successfulShards(), + equalTo(createSnapshotResponse.getSnapshotInfo().totalShards()) + ); + assertThat(createSnapshotResponse.getSnapshotInfo().successfulShards(), equalTo(0)); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + + // snapshot is PARTIAL because it 
includes shards on nodes with a missing repository password + final String snapshotName2 = snapshotName + "2"; + CreateSnapshotResponse incompleteSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName2) + .setWaitForCompletion(true) + .setIndices(indexName) + .get(); + assertThat(incompleteSnapshotResponse.getSnapshotInfo().state(), equalTo(SnapshotState.PARTIAL)); + assertTrue( + incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .allMatch(shardFailure -> shardFailure.reason().contains("[" + repositoryName + "] missing")) + ); + assertThat( + incompleteSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + incompleteSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + final Set nodesWithFailures = incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .map(sf -> sf.nodeId()) + .collect(Collectors.toSet()); + assertThat(nodesWithFailures.size(), equalTo(1)); + final ClusterStateResponse clusterState = client().admin().cluster().prepareState().clear().setNodes(true).get(); + assertThat(clusterState.getState().nodes().get(nodesWithFailures.iterator().next()).getName(), equalTo(otherNode)); + } + + public void testSnapshotIsPartialForDifferentPassword() throws Exception { + final String repoName = randomName(); + final Settings repoSettings = repositorySettings(repoName); + final String repoPass1 = randomAlphaOfLength(20); + final String repoPass2 = randomAlphaOfLength(19); + MockSecureSettings secureSettingsMaster = new MockSecureSettings(); + secureSettingsMaster.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + repoPass1 + ); + MockSecureSettings secureSettingsOther = new MockSecureSettings(); + secureSettingsOther.setString( + 
EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + repoPass2 + ); + final boolean putRepoEarly = randomBoolean(); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNode = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsMaster).build()); + ensureStableCluster(1); + if (putRepoEarly) { + createRepository(repoName, repoSettings, true); + } + final String otherNode = internalCluster().startNode(Settings.builder().setSecureSettings(secureSettingsOther).build()); + ensureStableCluster(2); + if (false == putRepoEarly) { + createRepository(repoName, repoSettings, false); + } + + // create index with shards on both nodes + final String indexName = randomName(); + final Settings indexSettings = Settings.builder().put(indexSettings()).put(SETTING_NUMBER_OF_SHARDS, 5).build(); + logger.info("--> create random index {}", indexName); + createIndex(indexName, indexSettings); + indexRandom( + true, + client().prepareIndex(indexName, "_doc", "1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName, "_doc", "2").setSource("field1", "quick brown"), + client().prepareIndex(indexName, "_doc", "3").setSource("field1", "quick"), + client().prepareIndex(indexName, "_doc", "4").setSource("field1", "lazy"), + client().prepareIndex(indexName, "_doc", "5").setSource("field1", "dog") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 5); + + // empty snapshot completes successfully for both repos because it does not involve any data + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repoName, snapshotName); + CreateSnapshotResponse createSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repoName, snapshotName) + .setIndices(indexName + "other*") + .setWaitForCompletion(true) + .get(); + assertThat( + 
createSnapshotResponse.getSnapshotInfo().successfulShards(), + equalTo(createSnapshotResponse.getSnapshotInfo().totalShards()) + ); + assertThat(createSnapshotResponse.getSnapshotInfo().successfulShards(), equalTo(0)); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + + // snapshot is PARTIAL because it includes shards on nodes with a different repository KEK + final String snapshotName2 = snapshotName + "2"; + CreateSnapshotResponse incompleteSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repoName, snapshotName2) + .setWaitForCompletion(true) + .setIndices(indexName) + .get(); + assertThat(incompleteSnapshotResponse.getSnapshotInfo().state(), equalTo(SnapshotState.PARTIAL)); + assertTrue( + incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .allMatch(shardFailure -> shardFailure.reason().contains("Repository password mismatch")) + ); + final Set nodesWithFailures = incompleteSnapshotResponse.getSnapshotInfo() + .shardFailures() + .stream() + .map(sf -> sf.nodeId()) + .collect(Collectors.toSet()); + assertThat(nodesWithFailures.size(), equalTo(1)); + final ClusterStateResponse clusterState = client().admin().cluster().prepareState().clear().setNodes(true).get(); + assertThat(clusterState.getState().nodes().get(nodesWithFailures.iterator().next()).getName(), equalTo(otherNode)); + } + + public void testWrongRepositoryPassword() throws Exception { + final String repositoryName = randomName(); + final Settings repositorySettings = repositorySettings(repositoryName); + final String goodPassword = randomAlphaOfLength(20); + final String wrongPassword = randomAlphaOfLength(19); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + 
secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + goodPassword + ); + logger.info("--> start 2 nodes"); + internalCluster().setBootstrapMasterNodeIndex(1); + internalCluster().startNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(2); + createRepository(repositoryName, repositorySettings, true); + // create empty snapshot + final String snapshotName = randomName(); + logger.info("--> create empty snapshot {}:{}", repositoryName, snapshotName); + CreateSnapshotResponse createSnapshotResponse = client().admin() + .cluster() + .prepareCreateSnapshot(repositoryName, snapshotName) + .setWaitForCompletion(true) + .get(); + assertThat( + createSnapshotResponse.getSnapshotInfo().successfulShards(), + equalTo(createSnapshotResponse.getSnapshotInfo().totalShards()) + ); + assertThat(createSnapshotResponse.getSnapshotInfo().successfulShards(), equalTo(0)); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY)) + ); + assertThat( + createSnapshotResponse.getSnapshotInfo().userMetadata(), + not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY)) + ); + // restart master node and fill in a wrong password + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + wrongPassword + ); + Set nodesWithWrongPassword = new HashSet<>(); + do { + String masterNodeName = internalCluster().getMasterName(); + logger.info("--> restart master node {}", masterNodeName); + internalCluster().restartNode(masterNodeName, new InternalTestCluster.RestartCallback()); + nodesWithWrongPassword.add(masterNodeName); + ensureStableCluster(2); + } while (false == nodesWithWrongPassword.contains(internalCluster().getMasterName())); + // maybe recreate the repository 
+ if (randomBoolean()) { + deleteRepository(repositoryName); + createRepository(repositoryName, repositorySettings, false); + } + // all repository operations return "repository password is incorrect", but the repository does not move to the corrupted state + final BlobStoreRepository blobStoreRepository = (BlobStoreRepository) internalCluster().getCurrentMasterNodeInstance( + RepositoriesService.class + ).repository(repositoryName); + RepositoryException e = expectThrows( + RepositoryException.class, + () -> PlainActionFuture.get( + f -> blobStoreRepository.threadPool().generic().execute(ActionRunnable.wrap(f, blobStoreRepository::getRepositoryData)) + ) + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareCreateSnapshot(repositoryName, snapshotName + "2").setWaitForCompletion(true).get() + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + e = expectThrows(RepositoryException.class, () -> client().admin().cluster().prepareGetSnapshots(repositoryName).get()); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareRestoreSnapshot(repositoryName, snapshotName).setWaitForCompletion(true).get() + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + e = expectThrows( + RepositoryException.class, + () -> client().admin().cluster().prepareDeleteSnapshot(repositoryName, snapshotName).get() + ); + assertThat(e.getCause().getMessage(), containsString("repository password is incorrect")); + // restart master node and fill in the good password + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + goodPassword + ); + do { + 
String masterNodeName = internalCluster().getMasterName(); + logger.info("--> restart master node {}", masterNodeName); + internalCluster().restartNode(masterNodeName, new InternalTestCluster.RestartCallback()); + nodesWithWrongPassword.remove(masterNodeName); + ensureStableCluster(2); + } while (nodesWithWrongPassword.contains(internalCluster().getMasterName())); + // ensure get snapshot works + GetSnapshotsResponse getSnapshotResponse = client().admin().cluster().prepareGetSnapshots(repositoryName).get(); + assertThat(getSnapshotResponse.getSnapshots(), hasSize(1)); + } + + public void testSnapshotFailsForMasterFailoverWithWrongPassword() throws Exception { + final String repoName = randomName(); + final Settings repoSettings = repositorySettings(repoName); + final String goodPass = randomAlphaOfLength(20); + final String wrongPass = randomAlphaOfLength(19); + MockSecureSettings secureSettingsWithPassword = new MockSecureSettings(); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + goodPass + ); + logger.info("--> start 4 nodes"); + internalCluster().setBootstrapMasterNodeIndex(0); + final String masterNode = internalCluster().startMasterOnlyNodes( + 1, + Settings.builder().setSecureSettings(secureSettingsWithPassword).build() + ).get(0); + final String otherNode = internalCluster().startDataOnlyNodes( + 1, + Settings.builder().setSecureSettings(secureSettingsWithPassword).build() + ).get(0); + ensureStableCluster(2); + secureSettingsWithPassword.setString( + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repoName).getKey(), + wrongPass + ); + internalCluster().startMasterOnlyNodes(2, Settings.builder().setSecureSettings(secureSettingsWithPassword).build()); + ensureStableCluster(4); + assertThat(internalCluster().getMasterName(), equalTo(masterNode)); + + logger.debug("--> creating repository [name: {}, verify: {}, settings: 
{}]", repoName, false, repoSettings); + assertAcked( + client().admin().cluster().preparePutRepository(repoName).setType(repositoryType()).setVerify(false).setSettings(repoSettings) + ); + // create index with just one shard on the "other" data node + final String indexName = randomName(); + final Settings indexSettings = Settings.builder() + .put(indexSettings()) + .put("index.routing.allocation.include._name", otherNode) + .put(SETTING_NUMBER_OF_SHARDS, 1) + .build(); + logger.info("--> create random index {}", indexName); + createIndex(indexName, indexSettings); + indexRandom( + true, + client().prepareIndex(indexName, "_doc", "1").setSource("field1", "the quick brown fox jumps"), + client().prepareIndex(indexName, "_doc", "2").setSource("field1", "quick brown"), + client().prepareIndex(indexName, "_doc", "3").setSource("field1", "quick"), + client().prepareIndex(indexName, "_doc", "4").setSource("field1", "lazy"), + client().prepareIndex(indexName, "_doc", "5").setSource("field1", "dog") + ); + assertHitCount(client().prepareSearch(indexName).setSize(0).get(), 5); + + // block shard snapshot on the data node + final LocalStateEncryptedRepositoryPlugin.TestEncryptedRepository otherNodeEncryptedRepo = + (LocalStateEncryptedRepositoryPlugin.TestEncryptedRepository) internalCluster().getInstance( + RepositoriesService.class, + otherNode + ).repository(repoName); + otherNodeEncryptedRepo.blockSnapshotShard(); + + final String snapshotName = randomName(); + logger.info("--> create snapshot {}:{}", repoName, snapshotName); + client().admin().cluster().prepareCreateSnapshot(repoName, snapshotName).setIndices(indexName).setWaitForCompletion(false).get(); + + // stop master + internalCluster().stopRandomNode(InternalTestCluster.nameFilter(masterNode)); + ensureStableCluster(3); + + otherNodeEncryptedRepo.unblockSnapshotShard(); + + // the failover master has the wrong password, snapshot fails + logger.info("--> waiting for completion"); + 
expectThrows(SnapshotMissingException.class, () -> { waitForCompletion(repoName, snapshotName, TimeValue.timeValueSeconds(60)); }); + } + + protected String randomName() { + return randomAlphaOfLength(randomIntBetween(1, 10)).toLowerCase(Locale.ROOT); + } + + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put("compress", randomBoolean()) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), FsRepository.TYPE) + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .put("location", randomRepoPath()) + .build(); + } + + protected String createRepository(final String name) { + return createRepository(name, true); + } + + protected String createRepository(final String name, final boolean verify) { + return createRepository(name, repositorySettings(name), verify); + } + + protected String createRepository(final String name, final Settings settings, final boolean verify) { + logger.debug("--> creating repository [name: {}, verify: {}, settings: {}]", name, verify, settings); + assertAcked( + client().admin().cluster().preparePutRepository(name).setType(repositoryType()).setVerify(verify).setSettings(settings) + ); + + internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { + assertThat(repositories.repository(name), notNullValue()); + assertThat(repositories.repository(name), instanceOf(BlobStoreRepository.class)); + assertThat(repositories.repository(name).isReadOnly(), is(settings.getAsBoolean("readonly", false))); + }); + + return name; + } + + protected void deleteRepository(final String name) { + logger.debug("--> deleting repository [name: {}]", name); + assertAcked(client().admin().cluster().prepareDeleteRepository(name)); + + internalCluster().getDataOrMasterNodeInstances(RepositoriesService.class).forEach(repositories -> { + 
RepositoryMissingException e = expectThrows(RepositoryMissingException.class, () -> repositories.repository(name)); + assertThat(e.repository(), equalTo(name)); + }); + } + + private void assertSuccessfulRestore(RestoreSnapshotResponse response) { + assertThat(response.getRestoreInfo().successfulShards(), greaterThan(0)); + assertThat(response.getRestoreInfo().successfulShards(), equalTo(response.getRestoreInfo().totalShards())); + } + + private void assertSuccessfulSnapshot(CreateSnapshotResponse response) { + assertThat(response.getSnapshotInfo().successfulShards(), greaterThan(0)); + assertThat(response.getSnapshotInfo().successfulShards(), equalTo(response.getSnapshotInfo().totalShards())); + assertThat(response.getSnapshotInfo().userMetadata(), not(hasKey(EncryptedRepository.PASSWORD_HASH_USER_METADATA_KEY))); + assertThat(response.getSnapshotInfo().userMetadata(), not(hasKey(EncryptedRepository.PASSWORD_SALT_USER_METADATA_KEY))); + } + + public SnapshotInfo waitForCompletion(String repository, String snapshotName, TimeValue timeout) throws InterruptedException { + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < timeout.millis()) { + List snapshotInfos = client().admin() + .cluster() + .prepareGetSnapshots(repository) + .setSnapshots(snapshotName) + .get() + .getSnapshots(); + assertThat(snapshotInfos.size(), equalTo(1)); + if (snapshotInfos.get(0).state().completed()) { + // Make sure that snapshot clean up operations are finished + ClusterStateResponse stateResponse = client().admin().cluster().prepareState().get(); + SnapshotsInProgress snapshotsInProgress = stateResponse.getState().custom(SnapshotsInProgress.TYPE); + if (snapshotsInProgress == null) { + return snapshotInfos.get(0); + } else { + boolean found = false; + for (SnapshotsInProgress.Entry entry : snapshotsInProgress.entries()) { + final Snapshot curr = entry.snapshot(); + if (curr.getRepository().equals(repository) && 
curr.getSnapshotId().getName().equals(snapshotName)) { + found = true; + break; + } + } + if (found == false) { + return snapshotInfos.get(0); + } + } + } + Thread.sleep(100); + } + fail("Timeout!!!"); + return null; + } +} diff --git a/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java new file mode 100644 index 0000000000000..c010c3e569431 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/internalClusterTest/java/org/elasticsearch/repositories/encrypted/EncryptedS3BlobStoreRepositoryIntegTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.MockSecureSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.License; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.repositories.s3.S3BlobStoreRepositoryTests; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.DEK_ROOT_CONTAINER; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.getEncryptedBlobByteLength; +import static org.hamcrest.Matchers.hasSize; + +public final class EncryptedS3BlobStoreRepositoryIntegTests extends S3BlobStoreRepositoryTests { + private static List repositoryNames; + + @BeforeClass + private static void preGenerateRepositoryNames() { + List names = new ArrayList<>(); + for (int i = 0; i < 32; i++) { + names.add("test-repo-" + i); + } + repositoryNames = Collections.synchronizedList(names); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal) { + Settings.Builder settingsBuilder = Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), License.LicenseType.TRIAL.getTypeName()); + MockSecureSettings superSecureSettings = (MockSecureSettings) settingsBuilder.getSecureSettings(); + superSecureSettings.merge(nodeSecureSettings()); + return settingsBuilder.build(); + } + + protected MockSecureSettings nodeSecureSettings() { + MockSecureSettings secureSettings = new MockSecureSettings(); + for (String repositoryName : repositoryNames) { + secureSettings.setString( + 
EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryName).getKey(), + String.format(Locale.ROOT, "%14s", repositoryName) // pad to the minimum pass length of 112 bits (14) + ); + } + return secureSettings; + } + + @Override + protected String randomRepositoryName() { + return repositoryNames.remove(randomIntBetween(0, repositoryNames.size() - 1)); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateEncryptedRepositoryPlugin.class, TestS3RepositoryPlugin.class); + } + + @Override + protected String repositoryType() { + return EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME; + } + + @Override + protected Settings repositorySettings(String repositoryName) { + return Settings.builder() + .put(super.repositorySettings(repositoryName)) + .put(EncryptedRepositoryPlugin.DELEGATE_TYPE_SETTING.getKey(), "s3") + .put(EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.getKey(), repositoryName) + .build(); + } + + @Override + protected void assertEmptyRepo(Map blobsMap) { + List blobs = blobsMap.keySet() + .stream() + .filter(blob -> false == blob.contains("index")) + .filter(blob -> false == blob.contains(DEK_ROOT_CONTAINER)) // encryption metadata "leaks" + .collect(Collectors.toList()); + assertThat("Only index blobs should remain in repository but found " + blobs, blobs, hasSize(0)); + } + + @Override + protected long blobLengthFromContentLength(long contentLength) { + return getEncryptedBlobByteLength(contentLength); + } + + @Override + public void testEnforcedCooldownPeriod() { + // this test is not applicable for the encrypted repository because it verifies behavior which pertains to snapshots that must + // be created before the encrypted repository was introduced, hence no such encrypted snapshots can possibly exist + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java 
b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java new file mode 100644 index 0000000000000..92a128d93848c --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/AESKeyUtils.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.settings.SecureString; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; +import java.security.Key; +import java.util.Base64; + +public final class AESKeyUtils { + public static final int KEY_LENGTH_IN_BYTES = 32; // 256-bit AES key + public static final int WRAPPED_KEY_LENGTH_IN_BYTES = KEY_LENGTH_IN_BYTES + 8; // https://www.ietf.org/rfc/rfc3394.txt section 2.2 + // parameter for the KDF function, it's a funny and unusual iter count larger than 60k + private static final int KDF_ITER = 61616; + // the KDF algorithm that generate the symmetric key given the password + private static final String KDF_ALGO = "PBKDF2WithHmacSHA512"; + // The Id of any AES SecretKey is the AES-Wrap-ciphertext of this fixed 32 byte wide array. + // Key wrapping encryption is deterministic (same plaintext generates the same ciphertext) + // and the probability that two different keys map the same plaintext to the same ciphertext is very small + // (2^-256, much lower than the UUID collision of 2^-128), assuming AES is indistinguishable from a pseudorandom permutation. 
+ private static final byte[] KEY_ID_PLAINTEXT = "wrapping known text forms key id".getBytes(StandardCharsets.UTF_8); + + public static byte[] wrap(SecretKey wrappingKey, SecretKey keyToWrap) throws GeneralSecurityException { + assert "AES".equals(wrappingKey.getAlgorithm()); + assert "AES".equals(keyToWrap.getAlgorithm()); + Cipher c = Cipher.getInstance("AESWrap"); + c.init(Cipher.WRAP_MODE, wrappingKey); + return c.wrap(keyToWrap); + } + + public static SecretKey unwrap(SecretKey wrappingKey, byte[] keyToUnwrap) throws GeneralSecurityException { + assert "AES".equals(wrappingKey.getAlgorithm()); + assert keyToUnwrap.length == WRAPPED_KEY_LENGTH_IN_BYTES; + Cipher c = Cipher.getInstance("AESWrap"); + c.init(Cipher.UNWRAP_MODE, wrappingKey); + Key unwrappedKey = c.unwrap(keyToUnwrap, "AES", Cipher.SECRET_KEY); + return new SecretKeySpec(unwrappedKey.getEncoded(), "AES"); // make sure unwrapped key is "AES" + } + + /** + * Computes the ID of the given AES {@code SecretKey}. + * The ID can be published as it does not leak any information about the key. + * Different {@code SecretKey}s have different IDs with a very high probability. + *

+ * The ID is the ciphertext of a known plaintext, using the AES Wrap cipher algorithm. + * AES Wrap algorithm is deterministic, i.e. encryption using the same key, of the same plaintext, generates the same ciphertext. + * Moreover, the ciphertext reveals no information on the key, and the probability of collision of ciphertexts given different + * keys is statistically negligible. + */ + public static String computeId(SecretKey secretAESKey) throws GeneralSecurityException { + byte[] ciphertextOfKnownPlaintext = wrap(secretAESKey, new SecretKeySpec(KEY_ID_PLAINTEXT, "AES")); + return new String(Base64.getUrlEncoder().withoutPadding().encode(ciphertextOfKnownPlaintext), StandardCharsets.UTF_8); + } + + public static SecretKey generatePasswordBasedKey(SecureString password, String salt) throws GeneralSecurityException { + return generatePasswordBasedKey(password, salt.getBytes(StandardCharsets.UTF_8)); + } + + public static SecretKey generatePasswordBasedKey(SecureString password, byte[] salt) throws GeneralSecurityException { + PBEKeySpec keySpec = new PBEKeySpec(password.getChars(), salt, KDF_ITER, KEY_LENGTH_IN_BYTES * Byte.SIZE); + SecretKeyFactory keyFactory = SecretKeyFactory.getInstance(KDF_ALGO); + SecretKey secretKey = keyFactory.generateSecret(keySpec); + SecretKeySpec secret = new SecretKeySpec(secretKey.getEncoded(), "AES"); + return secret; + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java new file mode 100644 index 0000000000000..91b56f012075e --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStream.java @@ -0,0 +1,563 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Locale; +import java.util.Objects; + +/** + * A {@code BufferOnMarkInputStream} adds the {@code mark} and {@code reset} functionality to another input stream. + * All the bytes read or skipped following a {@link #mark(int)} call are also stored in a fixed-size internal array + * so they can be replayed following a {@link #reset()} call. The size of the internal buffer is specified at construction + * time. It is an error (throws {@code IllegalArgumentException}) to specify a larger {@code readlimit} value as an argument + * to a {@code mark} call. + *

+ * Unlike the {@link java.io.BufferedInputStream} this only buffers upon a {@link #mark(int)} call, + * i.e. if {@code mark} is never called this is equivalent to a bare pass-through {@link FilterInputStream}. + * Moreover, this does not buffer in advance, so the amount of bytes read from this input stream, at any time, is equal to the amount + * read from the underlying stream (provided that reset has not been called, in which case bytes are replayed from the internal buffer + * and no bytes are read from the underlying stream). + *

+ * Close will also close the underlying stream and any subsequent {@code read}, {@code skip}, {@code available} and + * {@code reset} calls will throw {@code IOException}s. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + */ +public final class BufferOnMarkInputStream extends InputStream { + + /** + * the underlying input stream supplying the actual bytes to read + */ + final InputStream source; + /** + * The fixed capacity buffer used to store the bytes following a {@code mark} call on the input stream, + * and which are then replayed after the {@code reset} call. + * The buffer permits appending bytes which can then be read, possibly multiple times, by also + * supporting the mark and reset operations on its own. + * Reading will not discard the bytes just read. Subsequent reads will return the + * next bytes, but the bytes can be replayed by reading after calling {@code reset}. + * The {@code mark} operation is used to adjust the position of the reset return position to the current + * read position and also discard the bytes read before. + */ + final RingBuffer ringBuffer; // package-protected for tests + /** + * {@code true} when the result of a read or a skip from the underlying source stream must also be stored in the buffer + */ + boolean storeToBuffer; // package-protected for tests + /** + * {@code true} when the returned bytes must come from the buffer and not from the underlying source stream + */ + boolean replayFromBuffer; // package-protected for tests + /** + * {@code true} when this stream is closed and any further calls throw IOExceptions + */ + boolean closed; // package-protected for tests + + /** + * Creates a {@code BufferOnMarkInputStream} that buffers a maximum of {@code bufferSize} elements + * from the wrapped input stream {@code source} in order to support {@code mark} and {@code reset}. + * The {@code bufferSize} is the maximum value for the {@code mark} readlimit argument. 
+ * + * @param source the underlying input buffer + * @param bufferSize the number of bytes that can be stored after a call to mark + */ + public BufferOnMarkInputStream(InputStream source, int bufferSize) { + this.source = source; + this.ringBuffer = new RingBuffer(bufferSize); + this.storeToBuffer = this.replayFromBuffer = false; + this.closed = false; + } + + /** + * Reads up to {@code len} bytes of data into an array of bytes from this + * input stream. If {@code len} is zero, then no bytes are read and {@code 0} + * is returned; otherwise, there is an attempt to read at least one byte. + * If the contents of the stream must be replayed following a {@code reset} + * call, the call will return buffered bytes which have been returned in a previous + * call. Otherwise it forwards the read call to the underlying source input stream. + * If no byte is available because there are no more bytes to replay following + * a reset (if a reset was called) and the underlying stream is exhausted, the + * value {@code -1} is returned; otherwise, at least one byte is read and stored + * into {@code b}, starting at offset {@code off}. + * + * @param b the buffer into which the data is read. + * @param off the start offset in the destination array {@code b} + * @param len the maximum number of bytes read. + * @return the total number of bytes read into the buffer, or + * {@code -1} if there is no more data because the end of + * the stream has been reached. + * @throws NullPointerException If {@code b} is {@code null}. + * @throws IndexOutOfBoundsException If {@code off} is negative, + * {@code len} is negative, or {@code len} is greater than + * {@code b.length - off} + * @throws IOException if this stream has been closed or an I/O error occurs on the underlying stream. 
+ * @see java.io.InputStream#read(byte[], int, int) + */ + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + // this is `Objects#checkFromIndexSize(off, len, b.length)` from Java 9 + if ((b.length | off | len) < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, % ringBuffer.getAvailableToWriteByteCount()) { + // can not fully write to buffer + // invalidate mark + storeToBuffer = false; + // empty buffer + ringBuffer.clear(); + } else { + ringBuffer.write(b, off, bytesRead); + } + } + return bytesRead; + } + + /** + * Reads the next byte of data from this input stream. The value + * byte is returned as an {@code int} in the range + * {@code 0} to {@code 255}. If no byte is available + * because the end of the stream has been reached, the value + * {@code -1} is returned. The end of the stream is reached if the + * end of the underlying stream is reached, and reset has not been + * called or there are no more bytes to replay following a reset. + * This method blocks until input data is available, the end of + * the stream is detected, or an exception is thrown. + * + * @return the next byte of data, or {@code -1} if the end of the + * stream is reached. + * @exception IOException if this stream has been closed or an I/O error occurs on the underlying stream. + * @see BufferOnMarkInputStream#read(byte[], int, int) + */ + @Override + public int read() throws IOException { + ensureOpen(); + byte[] arr = new byte[1]; + int readResult = read(arr, 0, arr.length); + if (readResult == -1) { + return -1; + } + return arr[0]; + } + + /** + * Skips over and discards {@code n} bytes of data from the + * input stream. The {@code skip} method may, for a variety of + * reasons, end up skipping over some smaller number of bytes, + * possibly {@code 0}. The actual number of bytes skipped is + * returned. + * + * @param n the number of bytes to be skipped. 
+ * @return the actual number of bytes skipped. + * @throws IOException if this stream is closed, or if {@code in.skip(n)} throws an IOException or, + * in the case that {@code mark} is called, if BufferOnMarkInputStream#read(byte[], int, int) throws an IOException + */ + @Override + public long skip(long n) throws IOException { + ensureOpen(); + if (n <= 0) { + return 0; + } + if (false == storeToBuffer) { + // integrity check of the replayFromBuffer state variable + if (replayFromBuffer) { + throw new IllegalStateException("Reset cannot be called without a preceding mark invocation"); + } + // if mark has not been called, no storing to the buffer is required + return source.skip(n); + } + long remaining = n; + int size = (int) Math.min(2048, remaining); + byte[] skipBuffer = new byte[size]; + while (remaining > 0) { + // skipping translates to a read so that the skipped bytes are stored in the buffer, + // so they can possibly be replayed after a reset + int bytesRead = read(skipBuffer, 0, (int) Math.min(size, remaining)); + if (bytesRead < 0) { + break; + } + remaining -= bytesRead; + } + return n - remaining; + } + + /** + * Returns an estimate of the number of bytes that can be read (or + * skipped over) from this input stream without blocking by the next + * caller of a method for this input stream. The next caller might be + * the same thread or another thread. A single read or skip of this + * many bytes will not block, but may read or skip fewer bytes. + * + * @return an estimate of the number of bytes that can be read (or skipped + * over) from this input stream without blocking. 
+ * @exception IOException if this stream is closed or if {@code in.available()} throws an IOException + */ + @Override + public int available() throws IOException { + ensureOpen(); + int bytesAvailable = 0; + if (replayFromBuffer) { + bytesAvailable += ringBuffer.getAvailableToReadByteCount(); + } + bytesAvailable += source.available(); + return bytesAvailable; + } + + /** + * Tests if this input stream supports the {@code mark} and {@code reset} methods. + * This always returns {@code true}. + */ + @Override + public boolean markSupported() { + return true; + } + + /** + * Marks the current position in this input stream. A subsequent call to + * the {@code reset} method repositions this stream at the last marked + * position so that subsequent reads re-read the same bytes. The bytes + * read or skipped following a {@code mark} call will be buffered internally + * and any previously buffered bytes are discarded. + *

     * The {@code readlimit} argument tells this input stream to
     * allow that many bytes to be read before the mark position can be
     * invalidated. The {@code readlimit} argument value must not be larger than
     * the {@code bufferSize} constructor argument value, as returned by
     * {@link #getMaxMarkReadlimit()}.
     *

     * Invalidation of the mark position when the read count exceeds the {@code readlimit}
     * argument of this call is not currently enforced. Instead, a mark position is invalidated
     * only when the read count exceeds the maximum read limit, as returned by
     * {@link #getMaxMarkReadlimit()}.
     *
     * @param readlimit the maximum limit of bytes that can be read before
     *                  the mark position can be invalidated.
     * @see BufferOnMarkInputStream#reset()
     * @see java.io.InputStream#mark(int)
     */
    @Override
    public void mark(int readlimit) {
        // readlimit is otherwise ignored but this defensively fails if the caller is expecting to be able to mark/reset more than this
        // instance can accommodate in the fixed ring buffer
        if (readlimit > ringBuffer.getBufferSize()) {
            throw new IllegalArgumentException(
                "Readlimit value [" + readlimit + "] exceeds the maximum value of [" + ringBuffer.getBufferSize() + "]"
            );
        } else if (readlimit < 0) {
            throw new IllegalArgumentException("Readlimit value [" + readlimit + "] cannot be negative");
        }
        // per the InputStream contract, mark on a closed stream is a no-op (it must not throw)
        if (closed) {
            return;
        }
        // signal that further read or skipped bytes must be stored to the buffer
        storeToBuffer = true;
        if (replayFromBuffer) {
            // the mark operation while replaying after a reset
            // this only discards the previously buffered bytes before the current position
            // as well as updates the mark position in the buffer
            ringBuffer.mark();
        } else {
            // any previously stored bytes are discarded because mark only has to retain bytes from this position on
            ringBuffer.clear();
        }
    }

    /**
     * Repositions this stream to the position at the time the {@code mark} method was last called on this input stream.
     * It throws an {@code IOException} if {@code mark} has not yet been called on this instance.
     * Internally, this resets the buffer to the last mark position and signals that further reads (and skips)
     * on this input stream must return bytes from the buffer and not from the underlying source stream
+ * + * @throws IOException if the stream has been closed or the number of bytes + * read since the last mark call exceeded {@link #getMaxMarkReadlimit()} + * @see java.io.InputStream#mark(int) + */ + @Override + public void reset() throws IOException { + ensureOpen(); + if (false == storeToBuffer) { + throw new IOException("Mark not called or has been invalidated"); + } + // signal that further reads/skips must be satisfied from the buffer and not from the underlying source stream + replayFromBuffer = true; + // position the buffer's read pointer back to the last mark position + ringBuffer.reset(); + } + + /** + * Closes this input stream as well as the underlying stream. + * + * @exception IOException if an I/O error occurs while closing the underlying stream. + */ + @Override + public void close() throws IOException { + if (false == closed) { + closed = true; + source.close(); + } + } + + /** + * Returns the maximum value for the {@code readlimit} argument of the {@link #mark(int)} method. + * This is the value of the {@code bufferSize} constructor argument and represents the maximum number + * of bytes that can be internally buffered (so they can be replayed after the reset call). + */ + public int getMaxMarkReadlimit() { + return ringBuffer.getBufferSize(); + } + + private void ensureOpen() throws IOException { + if (closed) { + throw new IOException("Stream has been closed"); + } + } + + /** + * This buffer is used to store all the bytes read or skipped after the last {@link BufferOnMarkInputStream#mark(int)} + * invocation. + *

     * The latest bytes written to the ring buffer are appended following the previous ones.
     * Reading back the bytes advances an internal pointer so that subsequent read calls return subsequent bytes.
     * However, read bytes are not discarded. The same bytes can be re-read following the {@link #reset()} invocation.
     * {@link #reset()} permits re-reading the bytes since the last {@link #mark()} call, or since the buffer instance
     * has been created or the {@link #clear()} method has been invoked.
     * Calling {@link #mark()} will discard all bytes read before, and calling {@link #clear()} will discard all the
     * bytes (new bytes must be written otherwise reading will return {@code 0} bytes).
     */
    static class RingBuffer {

        /**
         * The capacity of this ring buffer; the backing array is lazily allocated
         * on the first {@link #write(byte[], int, int)} invocation
         */
        private final int bufferSize;
        /**
         * The array used to store the bytes to be replayed upon a reset call.
         */
        byte[] buffer; // package-protected for tests
        /**
         * The start offset (inclusive) for the bytes that must be re-read after a reset call. This offset is advanced
         * by invoking {@link #mark()}
         */
        int head; // package-protected for tests
        /**
         * The end offset (exclusive) for the bytes that must be re-read after a reset call. This offset is advanced
         * by writing to the ring buffer.
         */
        int tail; // package-protected for tests
        /**
         * The offset of the bytes to return on the next read call. This offset is advanced by reading from the ring buffer.
         */
        int position; // package-protected for tests

        /**
         * Creates a new ring buffer instance that can store a maximum of {@code bufferSize} bytes.
         * More bytes are stored by writing to the ring buffer, and bytes are discarded from the buffer by the
         * {@code mark} and {@code clear} method invocations.
+ */ + RingBuffer(int bufferSize) { + if (bufferSize <= 0) { + throw new IllegalArgumentException("The buffersize constructor argument must be a strictly positive value"); + } + this.bufferSize = bufferSize; + } + + /** + * Returns the maximum number of bytes that this buffer can store. + */ + int getBufferSize() { + return bufferSize; + } + + /** + * Rewind back to the read position of the last {@link #mark()} or {@link #reset()}. The next + * {@link RingBuffer#read(byte[], int, int)} call will return the same bytes that the read + * call after the last {@link #mark()} did. + */ + void reset() { + position = head; + } + + /** + * Mark the current read position. Any previously read bytes are discarded from the ring buffer, + * i.e. they cannot be re-read, but this frees up space for writing other bytes. + * All the following {@link RingBuffer#read(byte[], int, int)} calls will revert back to this position. + */ + void mark() { + head = position; + } + + /** + * Empties out the ring buffer, discarding all the bytes written to it, i.e. any following read calls don't + * return any bytes. + */ + void clear() { + head = position = tail = 0; + } + + /** + * Copies up to {@code len} bytes from the ring buffer and places them in the {@code b} array starting at offset {@code off}. + * This advances the internal pointer of the ring buffer so that a subsequent call will return the following bytes, not the + * same ones (see {@link #reset()}). + * Exactly {@code len} bytes are copied from the ring buffer, but no more than {@link #getAvailableToReadByteCount()}; i.e. + * if {@code len} is greater than the value returned by {@link #getAvailableToReadByteCount()} this reads all the remaining + * available bytes (which could be {@code 0}). + * This returns the exact count of bytes read (the minimum of {@code len} and the value of {@code #getAvailableToReadByteCount}). 
+ * + * @param b the array where to place the bytes read + * @param off the offset in the array where to start placing the bytes read (i.e. first byte is stored at b[off]) + * @param len the maximum number of bytes to read + * @return the number of bytes actually read + */ + int read(byte[] b, int off, int len) { + Objects.requireNonNull(b); + // this is `Objects#checkFromIndexSize(off, len, b.length)` from Java 9 + if ((b.length | off | len) < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, %exactly {@code len} bytes from the array {@code b}, starting at offset {@code off}, into the ring buffer. + * The bytes are appended after the ones written in the same way by a previous call, and are available to + * {@link #read(byte[], int, int)} immediately. + * This throws {@code IllegalArgumentException} if the ring buffer does not have enough space left. + * To get the available capacity left call {@link #getAvailableToWriteByteCount()}. + * + * @param b the array from which to copy the bytes into the ring buffer + * @param off the offset of the first element to copy + * @param len the number of elements to copy + */ + void write(byte[] b, int off, int len) { + Objects.requireNonNull(b); + // this is `Objects#checkFromIndexSize(off, len, b.length)` from Java 9 + if ((b.length | off | len) < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, % 0) { + // "+ 1" for the full-buffer sentinel element + buffer = new byte[bufferSize + 1]; + head = position = tail = 0; + } + if (len > getAvailableToWriteByteCount()) { + throw new IllegalArgumentException("Not enough remaining space in the ring buffer"); + } + while (len > 0) { + final int writeLength; + if (head <= tail) { + writeLength = Math.min(len, buffer.length - tail - (head == 0 ? 
1 : 0)); + } else { + writeLength = Math.min(len, head - tail - 1); + } + if (writeLength <= 0) { + throw new IllegalStateException("No space left in the ring buffer"); + } + System.arraycopy(b, off, buffer, tail, writeLength); + tail += writeLength; + off += writeLength; + len -= writeLength; + if (tail == buffer.length) { + tail = 0; + // tail wrap-around overwrites head + if (head == 0) { + throw new IllegalStateException("Possible overflow of the ring buffer"); + } + } + } + } + + /** + * Returns the number of bytes that can be written to this ring buffer before it becomes full + * and will not accept further writes. Be advised that reading (see {@link #read(byte[], int, int)}) + * does not free up space because bytes can be re-read multiple times (see {@link #reset()}); + * ring buffer space can be reclaimed by calling {@link #mark()} or {@link #clear()} + */ + int getAvailableToWriteByteCount() { + if (buffer == null) { + return bufferSize; + } + if (head == tail) { + return buffer.length - 1; + } else if (head < tail) { + return buffer.length - tail + head - 1; + } else { + return head - tail - 1; + } + } + + /** + * Returns the number of bytes that can be read from this ring buffer before it becomes empty + * and all subsequent {@link #read(byte[], int, int)} calls will return {@code 0}. Writing + * more bytes (see {@link #write(byte[], int, int)}) will obviously increase the number of + * bytes available to read. Calling {@link #reset()} will also increase the available byte + * count because the following reads will go over again the same bytes since the last + * {@code mark} call. 
+ */ + int getAvailableToReadByteCount() { + if (buffer == null) { + return 0; + } + if (head <= tail) { + return tail - position; + } else if (position >= head) { + return buffer.length - position + tail; + } else { + return tail - position; + } + } + + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java new file mode 100644 index 0000000000000..fed8b152bcdb5 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/ChainingInputStream.java @@ -0,0 +1,434 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.core.internal.io.IOUtils; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.util.Locale; +import java.util.Objects; + +/** + * A {@code ChainingInputStream} concatenates multiple component input streams into a + * single input stream. + * It starts reading from the first input stream until it's exhausted, whereupon + * it closes it and starts reading from the next one, until the last component input + * stream is exhausted. + *

+ * The implementing subclass provides the component input streams by implementing the + * {@link #nextComponent(InputStream)} method. This method receives the instance of the + * current input stream, which has been exhausted, and must return the next input stream, + * or {@code null} if there are no more component streams. + * The {@code ChainingInputStream} assumes ownership of the newly generated component input + * stream, i.e. components should not be used by other callers and they will be closed + * when they are exhausted or when the {@code ChainingInputStream} is closed. + *

+ * This stream does support {@code mark} and {@code reset} but it expects that the component + * streams also support it. When {@code mark} is invoked on the chaining input stream, the + * call is forwarded to the current input stream component and a reference to that component + * is stored internally. A {@code reset} invocation on the chaining input stream will then make the + * stored component the current component and will then call the {@code reset} on it. + * The {@link #nextComponent(InputStream)} method must be able to generate the same components + * anew, starting from the component of the {@code reset} call. + * If the component input streams do not support {@code mark}/{@code reset} or + * {@link #nextComponent(InputStream)} cannot generate the same component multiple times, + * the implementing subclass must override {@link #markSupported()} to return {@code false}. + *

+ * The {@code close} call will close the current component input stream and any subsequent {@code read}, + * {@code skip}, {@code available} and {@code reset} calls will throw {@code IOException}s. + *

+ * The {@code ChainingInputStream} is similar in purpose to the {@link java.io.SequenceInputStream}, + * with the addition of {@code mark}/{@code reset} support. + *

 * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access.
 */
public abstract class ChainingInputStream extends InputStream {

    private static final Logger LOGGER = LogManager.getLogger(ChainingInputStream.class);

    /**
     * value for the current input stream when there are no subsequent streams remaining, i.e. when
     * {@link #nextComponent(InputStream)} returns {@code null}
     * (this is `InputStream#nullInputStream` from Java 9)
     */
    protected static final InputStream EXHAUSTED_MARKER = new ByteArrayInputStream(new byte[0]); // protected for tests

    /**
     * The instance of the currently in use component input stream,
     * i.e. the instance currently servicing the read and skip calls on the {@code ChainingInputStream}
     */
    protected InputStream currentIn; // protected for tests
    /**
     * The instance of the component input stream at the time of the last {@code mark} call.
     */
    protected InputStream markIn; // protected for tests
    /**
     * {@code true} if {@link #close()} has been called; any subsequent {@code read}, {@code skip},
     * {@code available} and {@code reset} calls will throw {@code IOException}s
     */
    private boolean closed;

    /**
     * Returns a new {@link ChainingInputStream} that concatenates the bytes to be read from the first
     * input stream with the bytes from the second input stream. The stream arguments must support
     * the {@code mark} and {@code reset} operations; otherwise use {@link SequenceInputStream}.
     *
     * @param first the input stream supplying the first bytes of the returned {@link ChainingInputStream}
     * @param second the input stream supplying the bytes after the {@code first} input stream has been exhausted
     * @throws IllegalArgumentException if either argument does not support {@code mark}
     */
    public static ChainingInputStream chain(InputStream first, InputStream second) {
        if (false == Objects.requireNonNull(first).markSupported()) {
            throw new IllegalArgumentException("The first component input stream does not support mark");
        }
        if (false == Objects.requireNonNull(second).markSupported()) {
            throw new IllegalArgumentException("The second component input stream does not support mark");
        }
        // components can be reused, and the {@code ChainingInputStream} eagerly closes components after every use
        // "first" and "second" are closed when the returned {@code ChainingInputStream} is closed
        final InputStream firstComponent = Streams.noCloseStream(first);
        final InputStream secondComponent = Streams.noCloseStream(second);
        // be sure to remember the start of components because they might be reused
        firstComponent.mark(Integer.MAX_VALUE);
        secondComponent.mark(Integer.MAX_VALUE);

        return new ChainingInputStream() {

            @Override
            InputStream nextComponent(InputStream currentComponentIn) throws IOException {
                if (currentComponentIn == null) {
                    // when returning the next component, start from its beginning
                    firstComponent.reset();
                    return firstComponent;
                } else if (currentComponentIn == firstComponent) {
                    // when returning the next component, start from its beginning
                    secondComponent.reset();
                    return secondComponent;
                } else if (currentComponentIn == secondComponent) {
                    // no more components
                    return null;
                } else {
                    throw new IllegalStateException("Unexpected component input stream");
                }
            }

            @Override
            public void close() throws IOException {
                // the no-close wrappers shield "first" and "second" from the eager per-component closes,
                // so they must be explicitly closed here, alongside the regular chaining close
                IOUtils.close(super::close, first, second);
            }
        };
    }

    /**
     * This method is responsible for generating the component input streams.
     * It is passed the current input stream and must return the successive one,
     * or {@code null} if the current component is the last one.
     * It is passed the {@code null} value at the very start, when no component
     * input stream has yet been generated.
     * The successive input stream returns the bytes (during reading) that should
     * logically follow the bytes that have been previously returned by the passed-in
     * {@code currentComponentIn}; i.e. the first {@code read} call on the next
     * component returns the byte logically following the last byte of the previous
     * component.
     * In order to support {@code mark}/{@code reset} this method must be able
     * to generate the successive input stream given any of the previously generated
     * ones, i.e. implementors must not assume that the passed-in argument is the
     * instance last returned by this method. Therefore, implementors must identify
     * the bytes that the passed-in component generated and must return a new
     * {@code InputStream} which returns the bytes that logically follow, even if
     * the same sequence has been previously returned by another component.
     * If this is not possible, and the implementation
     * can only generate the component input streams once, it must override
     * {@link #markSupported()} to return {@code false}.
     */
    abstract @Nullable InputStream nextComponent(@Nullable InputStream currentComponentIn) throws IOException;

    /**
     * Reads the next byte of data from this chaining input stream.
     * The value byte is returned as an {@code int} in the range
     * {@code 0} to {@code 255}. If no byte is available
     * because the end of the stream has been reached, the value
     * {@code -1} is returned. The end of the chaining input stream
     * is reached when the end of the last component stream is reached.
+ * This method blocks until input data is available (possibly + * asking for the next input stream component), the end of + * the stream is detected, or an exception is thrown. + * + * @return the next byte of data, or {@code -1} if the end of the + * stream is reached. + * @exception IOException if this stream has been closed or + * an I/O error occurs on the current component stream. + * @see ChainingInputStream#read(byte[], int, int) + */ + @Override + public int read() throws IOException { + ensureOpen(); + do { + int byteVal = currentIn == null ? -1 : currentIn.read(); + if (byteVal != -1) { + return byteVal; + } + } while (nextIn()); + return -1; + } + + /** + * Reads up to {@code len} bytes of data into an array of bytes from this + * chaining input stream. If {@code len} is zero, then no bytes are read + * and {@code 0} is returned; otherwise, there is an attempt to read at least one byte. + * The {@code read} call is forwarded to the current component input stream. + * If the current component input stream is exhausted, the next one is obtained + * by invoking {@link #nextComponent(InputStream)} and the {@code read} call is + * forwarded to that. If the current component is exhausted + * and there is no subsequent component the value {@code -1} is returned; + * otherwise, at least one byte is read and stored into {@code b}, starting at + * offset {@code off}. + * + * @param b the buffer into which the data is read. + * @param off the start offset in the destination array {@code b} + * @param len the maximum number of bytes read. + * @return the total number of bytes read into the buffer, or + * {@code -1} if there is no more data because the current + * input stream component is exhausted and there is no next one + * {@link #nextComponent(InputStream)} retuns {@code null}. + * @throws NullPointerException If {@code b} is {@code null}. 
+ * @throws IndexOutOfBoundsException If {@code off} is negative, + * {@code len} is negative, or {@code len} is greater than + * {@code b.length - off} + * @throws IOException if this stream has been closed or an I/O error + * occurs on the current component input stream. + * @see java.io.InputStream#read(byte[], int, int) + */ + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + // this is `Objects#checkFromIndexSize(off, len, b.length)` from Java 9 + if ((b.length | off | len) < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, % 0) { + long bytesSkipped = currentIn.skip(bytesRemaining); + if (bytesSkipped == 0) { + int byteRead = read(); + if (byteRead == -1) { + break; + } else { + bytesRemaining--; + } + } else { + bytesRemaining -= bytesSkipped; + } + } + return n - bytesRemaining; + } + + /** + * Returns an estimate of the number of bytes that can be read (or + * skipped over) from this chaining input stream without blocking by the next + * caller of a method for this stream. The next caller might be + * the same thread or another thread. A single read or skip of this + * many bytes will not block, but may read or skip fewer bytes. + *

     * This simply forwards the {@code available} call to the current
     * component input stream, so the returned value is a conservative
     * lower bound of the available byte count; i.e. it's possible that
     * subsequent component streams have available bytes but this method
     * only returns the available bytes of the current component.
     *
     * @return an estimate of the number of bytes that can be read (or skipped
     *         over) from this input stream without blocking.
     * @exception IOException if this stream is closed or if
     *            {@code currentIn.available()} throws an IOException
     */
    @Override
    public int available() throws IOException {
        ensureOpen();
        if (currentIn == null) {
            // obtain the first component; nextIn() leaves currentIn non-null
            // (it becomes EXHAUSTED_MARKER when there are no components)
            nextIn();
        }
        return currentIn.available();
    }

    /**
     * Tests if this chaining input stream supports the {@code mark} and
     * {@code reset} methods. By default this returns {@code true} but there
     * are some requirements for how components are generated (see
     * {@link #nextComponent(InputStream)}), in which case, if the implementer
     * cannot satisfy them, it should override this to return {@code false}.
     */
    @Override
    public boolean markSupported() {
        return true;
    }

    /**
     * Marks the current position in this input stream. A subsequent call to
     * the {@code reset} method repositions this stream at the last marked
     * position so that subsequent reads re-read the same bytes.
     *

     * The {@code readlimit} argument tells this input stream to
     * allow that many bytes to be read before the mark position can be
     * invalidated.
     *

     * The {@code mark} call is forwarded to the current component input
     * stream and a reference to it is stored internally.
     *
     * @param readlimit the maximum limit of bytes that can be read before
     *                  the mark position can be invalidated.
     * @see ChainingInputStream#reset()
     * @see java.io.InputStream#mark(int)
     */
    @Override
    public void mark(int readlimit) {
        // per the InputStream contract, mark is a no-op when unsupported or when the stream is closed
        if (markSupported() && false == closed) {
            // closes any previously stored mark input stream
            if (markIn != null && markIn != EXHAUSTED_MARKER && currentIn != markIn) {
                try {
                    markIn.close();
                } catch (IOException e) {
                    // an IOException on a component input stream close is not important
                    LOGGER.info("IOException while closing a marked component input stream during a mark", e);
                }
            }
            // stores the current input stream to be reused in case of a reset
            markIn = currentIn;
            if (markIn != null && markIn != EXHAUSTED_MARKER) {
                markIn.mark(readlimit);
            }
        }
    }

    /**
     * Repositions this stream to the position at the time the
     * {@code mark} method was last called on this chaining input stream,
     * or at the beginning if the {@code mark} method was never called.
     * Subsequent read calls will return the same bytes in the same
     * order since the point of the {@code mark} call. Naturally,
     * {@code mark} can be invoked at any moment, even after a
     * {@code reset}.
     *

+ * The previously stored reference to the current component during the + * {@code mark} invocation is made the new current component and then + * the {@code reset} call is forwarded to it. The next internal call to + * {@link #nextComponent(InputStream)} will use this component, so + * the {@link #nextComponent(InputStream)} must not assume monotonic + * arguments. + * + * @throws IOException if the stream has been closed or the number of bytes + * read since the last mark call exceeded the + * {@code readLimit} parameter + * @see java.io.InputStream#mark(int) + */ + @Override + public void reset() throws IOException { + ensureOpen(); + if (false == markSupported()) { + throw new IOException("Mark/reset not supported"); + } + if (currentIn != null && currentIn != EXHAUSTED_MARKER && currentIn != markIn) { + try { + currentIn.close(); + } catch (IOException e) { + // an IOException on a component input stream close is not important + LOGGER.info("IOException while closing the current component input stream during a reset", e); + } + } + currentIn = markIn; + if (currentIn != null && currentIn != EXHAUSTED_MARKER) { + currentIn.reset(); + } + } + + /** + * Closes this chaining input stream, closing the current component stream as well + * as any internally stored reference of a component during a {@code mark} call. + * + * @exception IOException if an I/O error occurs while closing the current or the marked stream. 
+ */ + @Override + public void close() throws IOException { + if (false == closed) { + closed = true; + if (currentIn != null && currentIn != EXHAUSTED_MARKER) { + currentIn.close(); + } + if (markIn != null && markIn != currentIn && markIn != EXHAUSTED_MARKER) { + markIn.close(); + } + } + } + + private void ensureOpen() throws IOException { + if (closed) { + throw new IOException("Stream is closed"); + } + } + + private boolean nextIn() throws IOException { + if (currentIn == EXHAUSTED_MARKER) { + return false; + } + // close the current component, but only if it is not saved because of mark + if (currentIn != null && currentIn != markIn) { + currentIn.close(); + } + currentIn = nextComponent(currentIn); + if (currentIn == null) { + currentIn = EXHAUSTED_MARKER; + return false; + } + if (markSupported() && false == currentIn.markSupported()) { + throw new IllegalStateException("Component input stream must support mark"); + } + return true; + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java new file mode 100644 index 0000000000000..91245beaa8707 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/CountingInputStream.java @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.repositories.encrypted; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Objects; + +/** + * A {@code CountingInputStream} wraps another input stream and counts the number of bytes + * that have been read or skipped. + *

+ * This input stream does no buffering on its own and only supports {@code mark} and + * {@code reset} if the underlying wrapped stream supports it. + *

+ * If the stream supports {@code mark} and {@code reset}, the byte count is also reset to the + * value that it had on the last {@code mark} call, thereby not counting the same bytes twice. + *

+ * If the {@code closeSource} constructor argument is {@code true}, closing this + * stream will also close the wrapped input stream. Apart from closing the wrapped + * stream in this case, the {@code close} method does nothing else. + */ +public final class CountingInputStream extends InputStream { + + private final InputStream source; + private final boolean closeSource; + long count; // package-protected for tests + long mark; // package-protected for tests + boolean closed; // package-protected for tests + + /** + * Wraps another input stream, counting the number of bytes read. + * + * @param source the input stream to be wrapped + * @param closeSource {@code true} if closing this stream will also close the wrapped stream + */ + public CountingInputStream(InputStream source, boolean closeSource) { + this.source = Objects.requireNonNull(source); + this.closeSource = closeSource; + this.count = 0L; + this.mark = -1L; + this.closed = false; + } + + /** Returns the number of bytes read. 
*/ + public long getCount() { + return count; + } + + @Override + public int read() throws IOException { + int result = source.read(); + if (result != -1) { + count++; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int result = source.read(b, off, len); + if (result != -1) { + count += result; + } + return result; + } + + @Override + public long skip(long n) throws IOException { + long result = source.skip(n); + count += result; + return result; + } + + @Override + public int available() throws IOException { + return source.available(); + } + + @Override + public boolean markSupported() { + return source.markSupported(); + } + + @Override + public synchronized void mark(int readlimit) { + source.mark(readlimit); + mark = count; + } + + @Override + public synchronized void reset() throws IOException { + if (false == source.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1L) { + throw new IOException("Mark not set"); + } + count = mark; + source.reset(); + } + + @Override + public void close() throws IOException { + if (false == closed) { + closed = true; + if (closeSource) { + source.close(); + } + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java new file mode 100644 index 0000000000000..635e5128d296c --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStream.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.util.ByteUtils; +import org.elasticsearch.core.internal.io.IOUtils; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.GCMParameterSpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Locale; +import java.util.Objects; + +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.GCM_IV_LENGTH_IN_BYTES; +import static org.elasticsearch.repositories.encrypted.EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + +/** + * A {@code DecryptionPacketsInputStream} wraps an encrypted input stream and decrypts + * its contents. This is designed (and tested) to decrypt only the encryption format that + * {@link EncryptionPacketsInputStream} generates. No decrypted bytes are returned before + * they are authenticated. + *

+ * The same parameters, namely {@code secretKey} and {@code packetLength}, + * which have been used during encryption, must also be used for decryption, + * otherwise decryption will fail. + *

+ * This implementation buffers the encrypted packet in memory. The maximum packet size it can + * accommodate is {@link EncryptedRepository#MAX_PACKET_LENGTH_IN_BYTES}. + *

+ * This implementation does not support {@code mark} and {@code reset}. + *

+ * The {@code close} call will close the decryption input stream and any subsequent {@code read}, + * {@code skip}, {@code available} and {@code reset} calls will throw {@code IOException}s. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + * + * @see EncryptionPacketsInputStream + */ +public final class DecryptionPacketsInputStream extends ChainingInputStream { + + private final InputStream source; + private final SecretKey secretKey; + private final int packetLength; + private final byte[] packetBuffer; + + private boolean hasNext; + private long counter; + + /** + * Computes and returns the length of the plaintext given the {@code ciphertextLength} and the {@code packetLength} + * used during encryption. + * Each ciphertext packet is prepended by the Initilization Vector and has the Authentication Tag appended. + * Decryption is 1:1, and the ciphertext is not padded, but stripping away the IV and the AT amounts to a shorter + * plaintext compared to the ciphertext. + * + * @see EncryptionPacketsInputStream#getEncryptionLength(long, int) + */ + public static long getDecryptionLength(long ciphertextLength, int packetLength) { + long encryptedPacketLength = packetLength + GCM_TAG_LENGTH_IN_BYTES + GCM_IV_LENGTH_IN_BYTES; + long completePackets = ciphertextLength / encryptedPacketLength; + long decryptedSize = completePackets * packetLength; + if (ciphertextLength % encryptedPacketLength != 0) { + decryptedSize += (ciphertextLength % encryptedPacketLength) - GCM_IV_LENGTH_IN_BYTES - GCM_TAG_LENGTH_IN_BYTES; + } + return decryptedSize; + } + + public DecryptionPacketsInputStream(InputStream source, SecretKey secretKey, int packetLength) { + this.source = Objects.requireNonNull(source); + this.secretKey = Objects.requireNonNull(secretKey); + if (packetLength <= 0 || packetLength >= EncryptedRepository.MAX_PACKET_LENGTH_IN_BYTES) { + throw new IllegalArgumentException("Invalid packet length [" + packetLength + "]"); + } + this.packetLength = packetLength; + this.packetBuffer = new byte[packetLength + GCM_TAG_LENGTH_IN_BYTES]; + this.hasNext = true; + this.counter = EncryptedRepository.PACKET_START_COUNTER; + 
} + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn != null && currentComponentIn.read() != -1) { + throw new IllegalStateException("Stream for previous packet has not been fully processed"); + } + if (false == hasNext) { + return null; + } + PrefixInputStream packetInputStream = new PrefixInputStream( + source, + packetLength + GCM_IV_LENGTH_IN_BYTES + GCM_TAG_LENGTH_IN_BYTES, + false + ); + int currentPacketLength = decrypt(packetInputStream); + // only the last packet is shorter, so this must be the last packet + if (currentPacketLength != packetLength) { + hasNext = false; + } + return new ByteArrayInputStream(packetBuffer, 0, currentPacketLength); + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark(int readlimit) {} + + @Override + public void reset() throws IOException { + throw new IOException("Mark/reset not supported"); + } + + @Override + public void close() throws IOException { + IOUtils.close(super::close, source); + } + + private int decrypt(PrefixInputStream packetInputStream) throws IOException { + // read only the IV prefix into the packet buffer + int ivLength = readNBytes(packetInputStream, packetBuffer, 0, GCM_IV_LENGTH_IN_BYTES); + if (ivLength != GCM_IV_LENGTH_IN_BYTES) { + throw new IOException("Packet heading IV error. Unexpected length [" + ivLength + "]."); + } + // extract the counter from the packet IV and validate it (that the packet is in order) + // skips the first 4 bytes in the packet IV, which contain the encryption nonce, which cannot be explicitly validated + // because the nonce is not passed in during decryption, but it is implicitly because it is part of the IV, + // when GCM validates the packet authn tag + long packetIvCounter = ByteUtils.readLongLE(packetBuffer, Integer.BYTES); + if (packetIvCounter != counter) { + throw new IOException("Packet counter mismatch. 
Expecting [" + counter + "], but got [" + packetIvCounter + "]."); + } + // counter increment for the subsequent packet + counter++; + // counter wrap around + if (counter == EncryptedRepository.PACKET_START_COUNTER) { + throw new IOException("Maximum packet count limit exceeded"); + } + // cipher used to decrypt only the current packetInputStream + Cipher packetCipher = getPacketDecryptionCipher(packetBuffer); + // read the rest of the packet, reusing the packetBuffer + int packetLength = readNBytes(packetInputStream, packetBuffer, 0, packetBuffer.length); + if (packetLength < GCM_TAG_LENGTH_IN_BYTES) { + throw new IOException("Encrypted packet is too short"); + } + try { + // in-place decryption of the whole packet and return decrypted length + return packetCipher.doFinal(packetBuffer, 0, packetLength, packetBuffer); + } catch (ShortBufferException | IllegalBlockSizeException | BadPaddingException e) { + throw new IOException("Exception during packet decryption", e); + } + } + + private Cipher getPacketDecryptionCipher(byte[] packet) throws IOException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(GCM_TAG_LENGTH_IN_BYTES * Byte.SIZE, packet, 0, GCM_IV_LENGTH_IN_BYTES); + try { + Cipher packetCipher = Cipher.getInstance(EncryptedRepository.DATA_ENCRYPTION_SCHEME); + packetCipher.init(Cipher.DECRYPT_MODE, secretKey, gcmParameterSpec); + return packetCipher; + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) { + throw new IOException("Exception during packet cipher initialisation", e); + } + } + + // this is `InputStream#readNBytes` from Java 9 + private static int readNBytes(InputStream inputStream, byte[] b, int off, int len) throws IOException { + if ((b.length | off | len) < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, %> dekGenerator; + // license is checked before every snapshot operations; protected non-final 
for tests + protected Supplier licenseStateSupplier; + private final SecureString repositoryPassword; + private final String localRepositoryPasswordHash; + private final String localRepositoryPasswordSalt; + private volatile String validatedLocalRepositoryPasswordHash; + private final Cache dekCache; + + /** + * When set to true metadata files are stored in compressed format. This setting doesn’t affect index + * files that are already compressed by default. Defaults to false. + */ + static final Setting COMPRESS_SETTING = Setting.boolSetting("compress", false); + + /** + * Returns the byte length (i.e. the storage size) of an encrypted blob, given the length of the blob's plaintext contents. + * + * @see EncryptionPacketsInputStream#getEncryptionLength(long, int) + */ + public static long getEncryptedBlobByteLength(long plaintextBlobByteLength) { + return (long) DEK_ID_LENGTH /* UUID byte length */ + + EncryptionPacketsInputStream.getEncryptionLength(plaintextBlobByteLength, PACKET_LENGTH_IN_BYTES); + } + + protected EncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry namedXContentRegistry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repositoryPassword + ) throws GeneralSecurityException { + super(metadata, COMPRESS_SETTING.get(metadata.settings()), namedXContentRegistry, clusterService, bigArrays, recoverySettings); + this.delegatedRepository = delegatedRepository; + this.dekGenerator = createDEKGenerator(); + this.licenseStateSupplier = licenseStateSupplier; + this.repositoryPassword = repositoryPassword; + // the salt used to generate an irreversible "hash"; it is generated randomly but it's fixed for the lifetime of the + // repository solely for efficiency reasons + this.localRepositoryPasswordSalt = UUIDs.randomBase64UUID(); + // the "hash" of the repository password from the local node is not 
actually a hash but the ciphertext of a + // known-plaintext using a key derived from the repository password using a random salt + this.localRepositoryPasswordHash = AESKeyUtils.computeId( + AESKeyUtils.generatePasswordBasedKey(repositoryPassword, localRepositoryPasswordSalt) + ); + // a "hash" computed locally is also locally trusted (trivially) + this.validatedLocalRepositoryPasswordHash = this.localRepositoryPasswordHash; + // stores decrypted DEKs; DEKs are reused to encrypt/decrypt multiple independent blobs + this.dekCache = CacheBuilder.builder().setMaximumWeight(DEK_CACHE_WEIGHT).build(); + if (isReadOnly() != delegatedRepository.isReadOnly()) { + throw new RepositoryException( + metadata.name(), + "Unexpected fatal internal error", + new IllegalStateException("The encrypted repository must be read-only iff the delegate repository is read-only") + ); + } + } + + @Override + public RepositoryStats stats() { + return this.delegatedRepository.stats(); + } + + /** + * The repository hook method which populates the snapshot metadata with the salted password hash of the repository on the (master) + * node that starts of the snapshot operation. All the other actions associated with the same snapshot operation will first verify + * that the local repository password checks with the hash from the snapshot metadata. + *

+ * In addition, if the installed license does not comply with the "encrypted snapshots" feature, this method throws an exception, + * which aborts the snapshot operation. + * + * See {@link org.elasticsearch.repositories.Repository#adaptUserMetadata(Map)}. + * + * @param userMetadata the snapshot metadata as received from the calling user + * @return the snapshot metadata containing the salted password hash of the node initializing the snapshot + */ + @Override + public Map adaptUserMetadata(Map userMetadata) { + // because populating the snapshot metadata must be done before the actual snapshot is first initialized, + // we take the opportunity to validate the license and abort if non-compliant + if (false == licenseStateSupplier.get().isAllowed(XPackLicenseState.Feature.ENCRYPTED_SNAPSHOT)) { + throw LicenseUtils.newComplianceException("encrypted snapshots"); + } + Map snapshotUserMetadata = new HashMap<>(); + if (userMetadata != null) { + snapshotUserMetadata.putAll(userMetadata); + } + // fill in the hash of the repository password, which is then checked before every snapshot operation + // (i.e. 
{@link #snapshotShard} and {@link #finalizeSnapshot}) to ensure that all participating nodes + // in the snapshot operation use the same repository password + snapshotUserMetadata.put(PASSWORD_SALT_USER_METADATA_KEY, localRepositoryPasswordSalt); + snapshotUserMetadata.put(PASSWORD_HASH_USER_METADATA_KEY, localRepositoryPasswordHash); + logger.trace( + "Snapshot metadata for local repository password [{}] and [{}]", + localRepositoryPasswordSalt, + localRepositoryPasswordHash + ); + // do not wrap in Map.of; we have to be able to modify the map (remove the added entries) when finalizing the snapshot + return snapshotUserMetadata; + } + + @Override + public void finalizeSnapshot( + ShardGenerations shardGenerations, + long repositoryStateId, + Metadata clusterMetadata, + SnapshotInfo snapshotInfo, + Version repositoryMetaVersion, + Function stateTransformer, + ActionListener listener + ) { + try { + validateLocalRepositorySecret(snapshotInfo.userMetadata()); + } catch (RepositoryException passwordValidationException) { + listener.onFailure(passwordValidationException); + return; + } finally { + // remove the repository password hash (and salt) from the snapshot metadata so that it is not displayed in the API response + // to the user + snapshotInfo.userMetadata().remove(PASSWORD_HASH_USER_METADATA_KEY); + snapshotInfo.userMetadata().remove(PASSWORD_SALT_USER_METADATA_KEY); + } + super.finalizeSnapshot( + shardGenerations, + repositoryStateId, + clusterMetadata, + snapshotInfo, + repositoryMetaVersion, + stateTransformer, + listener + ); + } + + @Override + public void snapshotShard( + Store store, + MapperService mapperService, + SnapshotId snapshotId, + IndexId indexId, + IndexCommit snapshotIndexCommit, + String shardStateIdentifier, + IndexShardSnapshotStatus snapshotStatus, + Version repositoryMetaVersion, + Map userMetadata, + ActionListener listener + ) { + try { + validateLocalRepositorySecret(userMetadata); + } catch (RepositoryException 
passwordValidationException) { + listener.onFailure(passwordValidationException); + return; + } + super.snapshotShard( + store, + mapperService, + snapshotId, + indexId, + snapshotIndexCommit, + shardStateIdentifier, + snapshotStatus, + repositoryMetaVersion, + userMetadata, + listener + ); + } + + @Override + protected BlobStore createBlobStore() { + final Supplier> blobStoreDEKGenerator; + if (isReadOnly()) { + // make sure that a read-only repository can't encrypt anything + blobStoreDEKGenerator = () -> { + throw new RepositoryException( + metadata.name(), + "Unexpected fatal internal error", + new IllegalStateException("DEKs are required for encryption but this is a read-only repository") + ); + }; + } else { + blobStoreDEKGenerator = this.dekGenerator; + } + return new EncryptedBlobStore( + delegatedRepository.blobStore(), + delegatedRepository.basePath(), + metadata.name(), + this::generateKEK, + blobStoreDEKGenerator, + dekCache + ); + } + + @Override + public BlobPath basePath() { + // the encrypted repository uses a hardcoded empty base blob path, + // but the base path setting is honored for the delegated repository + return BlobPath.cleanPath(); + } + + @Override + protected void doStart() { + this.delegatedRepository.start(); + super.doStart(); + } + + @Override + protected void doStop() { + super.doStop(); + this.delegatedRepository.stop(); + } + + @Override + protected void doClose() { + super.doClose(); + this.delegatedRepository.close(); + } + + private Supplier> createDEKGenerator() throws GeneralSecurityException { + // DEK and DEK Ids MUST be generated randomly (with independent random instances) + // the rand algo is not pinned so that it goes well with various providers (eg FIPS) + // TODO maybe we can make this a setting for rigurous users + final SecureRandom dekSecureRandom = new SecureRandom(); + final SecureRandom dekIdSecureRandom = new SecureRandom(); + final KeyGenerator dekGenerator = 
KeyGenerator.getInstance(DATA_ENCRYPTION_SCHEME.split("/")[0]); + dekGenerator.init(AESKeyUtils.KEY_LENGTH_IN_BYTES * Byte.SIZE, dekSecureRandom); + return () -> { + final BytesReference dekId = new BytesArray(UUIDs.randomBase64UUID(dekIdSecureRandom)); + final SecretKey dek = dekGenerator.generateKey(); + logger.debug("Repository [{}] generated new DEK [{}]", metadata.name(), dekId); + return new Tuple<>(dekId, dek); + }; + } + + // pkg-private for tests + Tuple generateKEK(String dekId) { + try { + // we rely on the DEK Id being generated randomly so it can be used as a salt + final SecretKey kek = AESKeyUtils.generatePasswordBasedKey(repositoryPassword, dekId); + final String kekId = AESKeyUtils.computeId(kek); + logger.debug("Repository [{}] computed KEK [{}] for DEK [{}]", metadata.name(), kekId, dekId); + return new Tuple<>(kekId, kek); + } catch (GeneralSecurityException e) { + throw new RepositoryException(metadata.name(), "Failure to generate KEK to wrap the DEK [" + dekId + "]", e); + } + } + + /** + * Called before the shard snapshot and finalize operations, on the data and master nodes. This validates that the repository + * password on the master node that started the snapshot operation is identical to the repository password on the local node. 
+ * + * @param snapshotUserMetadata the snapshot metadata containing the repository password hash to assert + * @throws RepositoryException if the repository password hash on the local node mismatches the master's + */ + private void validateLocalRepositorySecret(Map snapshotUserMetadata) throws RepositoryException { + assert snapshotUserMetadata != null; + assert snapshotUserMetadata.get(PASSWORD_HASH_USER_METADATA_KEY) instanceof String; + final String masterRepositoryPasswordId = (String) snapshotUserMetadata.get(PASSWORD_HASH_USER_METADATA_KEY); + if (false == masterRepositoryPasswordId.equals(validatedLocalRepositoryPasswordHash)) { + assert snapshotUserMetadata.get(PASSWORD_SALT_USER_METADATA_KEY) instanceof String; + final String masterRepositoryPasswordIdSalt = (String) snapshotUserMetadata.get(PASSWORD_SALT_USER_METADATA_KEY); + final String computedRepositoryPasswordId; + try { + computedRepositoryPasswordId = AESKeyUtils.computeId( + AESKeyUtils.generatePasswordBasedKey(repositoryPassword, masterRepositoryPasswordIdSalt) + ); + } catch (Exception e) { + throw new RepositoryException(metadata.name(), "Unexpected fatal internal error", e); + } + if (computedRepositoryPasswordId.equals(masterRepositoryPasswordId)) { + this.validatedLocalRepositoryPasswordHash = computedRepositoryPasswordId; + } else { + throw new RepositoryException( + metadata.name(), + "Repository password mismatch. 
The local node's repository password, from the keystore setting [" + + EncryptedRepositoryPlugin.ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace( + EncryptedRepositoryPlugin.PASSWORD_NAME_SETTING.get(metadata.settings()) + ).getKey() + + "], is different compared to the elected master node's which started the snapshot operation" + ); + } + } + } + + // pkg-private for tests + static final class EncryptedBlobStore implements BlobStore { + private final BlobStore delegatedBlobStore; + private final BlobPath delegatedBasePath; + private final String repositoryName; + private final Function> getKEKforDEK; + private final Cache dekCache; + private final CheckedSupplier singleUseDEKSupplier; + + EncryptedBlobStore( + BlobStore delegatedBlobStore, + BlobPath delegatedBasePath, + String repositoryName, + Function> getKEKforDEK, + Supplier> dekGenerator, + Cache dekCache + ) { + this.delegatedBlobStore = delegatedBlobStore; + this.delegatedBasePath = delegatedBasePath; + this.repositoryName = repositoryName; + this.getKEKforDEK = getKEKforDEK; + this.dekCache = dekCache; + this.singleUseDEKSupplier = SingleUseKey.createSingleUseKeySupplier(() -> { + Tuple newDEK = dekGenerator.get(); + // store the newly generated DEK before making it available + storeDEK(newDEK.v1().utf8ToString(), newDEK.v2()); + return newDEK; + }); + } + + // pkg-private for tests + SecretKey getDEKById(String dekId) throws IOException { + try { + return dekCache.computeIfAbsent(dekId, ignored -> loadDEK(dekId)); + } catch (ExecutionException e) { + // some exception types are to be expected + if (e.getCause() instanceof IOException) { + throw (IOException) e.getCause(); + } else if (e.getCause() instanceof ElasticsearchException) { + throw (ElasticsearchException) e.getCause(); + } else { + throw new RepositoryException(repositoryName, "Unexpected exception retrieving DEK [" + dekId + "]", e); + } + } + } + + private SecretKey loadDEK(String dekId) throws IOException { + final BlobPath 
dekBlobPath = delegatedBasePath.add(DEK_ROOT_CONTAINER).add(dekId); + logger.debug("Repository [{}] loading wrapped DEK [{}] from blob path {}", repositoryName, dekId, dekBlobPath); + final BlobContainer dekBlobContainer = delegatedBlobStore.blobContainer(dekBlobPath); + final Tuple kekTuple = getKEKforDEK.apply(dekId); + final String kekId = kekTuple.v1(); + final SecretKey kek = kekTuple.v2(); + logger.trace("Repository [{}] using KEK [{}] to unwrap DEK [{}]", repositoryName, kekId, dekId); + final byte[] encryptedDEKBytes = new byte[AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES]; + try (InputStream encryptedDEKInputStream = dekBlobContainer.readBlob(kekId)) { + final int bytesRead = Streams.readFully(encryptedDEKInputStream, encryptedDEKBytes); + if (bytesRead != AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES) { + throw new RepositoryException( + repositoryName, + "Wrapped DEK [" + dekId + "] has smaller length [" + bytesRead + "] than expected" + ); + } + if (encryptedDEKInputStream.read() != -1) { + throw new RepositoryException(repositoryName, "Wrapped DEK [" + dekId + "] is larger than expected"); + } + } catch (NoSuchFileException e) { + // do NOT throw IOException when the DEK does not exist, as this is a decryption problem, and IOExceptions + // can move the repository in the corrupted state + throw new ElasticsearchException( + "Failure to read and decrypt DEK [" + + dekId + + "] from " + + dekBlobContainer.path() + + ". 
Most likely the repository password is incorrect, where previous " + + "snapshots have used a different password.", + e + ); + } + logger.trace("Repository [{}] successfully read DEK [{}] from path {} {}", repositoryName, dekId, dekBlobPath, kekId); + try { + final SecretKey dek = AESKeyUtils.unwrap(kek, encryptedDEKBytes); + logger.debug("Repository [{}] successfully loaded DEK [{}] from path {} {}", repositoryName, dekId, dekBlobPath, kekId); + return dek; + } catch (GeneralSecurityException e) { + throw new RepositoryException( + repositoryName, + "Failure to AES unwrap the DEK [" + + dekId + + "]. " + + "Most likely the encryption metadata in the repository has been corrupted", + e + ); + } + } + + // pkg-private for tests + void storeDEK(String dekId, SecretKey dek) throws IOException { + final BlobPath dekBlobPath = delegatedBasePath.add(DEK_ROOT_CONTAINER).add(dekId); + logger.debug("Repository [{}] storing wrapped DEK [{}] under blob path {}", repositoryName, dekId, dekBlobPath); + final BlobContainer dekBlobContainer = delegatedBlobStore.blobContainer(dekBlobPath); + final Tuple kek = getKEKforDEK.apply(dekId); + logger.trace("Repository [{}] using KEK [{}] to wrap DEK [{}]", repositoryName, kek.v1(), dekId); + final byte[] encryptedDEKBytes; + try { + encryptedDEKBytes = AESKeyUtils.wrap(kek.v2(), dek); + if (encryptedDEKBytes.length != AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES) { + throw new RepositoryException( + repositoryName, + "Wrapped DEK [" + dekId + "] has unexpected length [" + encryptedDEKBytes.length + "]" + ); + } + } catch (GeneralSecurityException e) { + // throw unchecked ElasticsearchException; IOExceptions are interpreted differently and can move the repository in the + // corrupted state + throw new RepositoryException(repositoryName, "Failure to AES wrap the DEK [" + dekId + "]", e); + } + logger.trace("Repository [{}] successfully wrapped DEK [{}]", repositoryName, dekId); + dekBlobContainer.writeBlobAtomic(kek.v1(), new 
BytesArray(encryptedDEKBytes), true); + logger.debug("Repository [{}] successfully stored DEK [{}] under path {} {}", repositoryName, dekId, dekBlobPath, kek.v1()); + } + + @Override + public BlobContainer blobContainer(BlobPath path) { + final Iterator pathIterator = path.iterator(); + BlobPath delegatedBlobContainerPath = delegatedBasePath; + while (pathIterator.hasNext()) { + delegatedBlobContainerPath = delegatedBlobContainerPath.add(pathIterator.next()); + } + final BlobContainer delegatedBlobContainer = delegatedBlobStore.blobContainer(delegatedBlobContainerPath); + return new EncryptedBlobContainer(path, repositoryName, delegatedBlobContainer, singleUseDEKSupplier, this::getDEKById); + } + + @Override + public void close() { + // do NOT close delegatedBlobStore; it will be closed when the inner delegatedRepository is closed + } + } + + private static final class EncryptedBlobContainer extends AbstractBlobContainer { + private final String repositoryName; + private final BlobContainer delegatedBlobContainer; + // supplier for the DEK used for encryption (snapshot) + private final CheckedSupplier singleUseDEKSupplier; + // retrieves the DEK required for decryption (restore) + private final CheckedFunction getDEKById; + + EncryptedBlobContainer( + BlobPath path, // this path contains the {@code EncryptedRepository#basePath} which, importantly, is empty + String repositoryName, + BlobContainer delegatedBlobContainer, + CheckedSupplier singleUseDEKSupplier, + CheckedFunction getDEKById + ) { + super(path); + this.repositoryName = repositoryName; + final String rootPathElement = path.iterator().hasNext() ? 
path.iterator().next() : null; + if (DEK_ROOT_CONTAINER.equals(rootPathElement)) { + throw new RepositoryException(repositoryName, "Cannot descend into the DEK blob container " + path); + } + this.delegatedBlobContainer = delegatedBlobContainer; + this.singleUseDEKSupplier = singleUseDEKSupplier; + this.getDEKById = getDEKById; + } + + @Override + public boolean blobExists(String blobName) throws IOException { + return delegatedBlobContainer.blobExists(blobName); + } + + /** + * Returns a new {@link InputStream} for the given {@code blobName} that can be used to read the contents of the blob. + * The returned {@code InputStream} transparently handles the decryption of the blob contents, by first working out + * the blob name of the associated DEK id, reading and decrypting the DEK (given the repository password, unless the DEK is + * already cached because it had been used for other blobs before), and lastly reading and decrypting the data blob, + * in a streaming fashion, by employing the {@link DecryptionPacketsInputStream}. + * The {@code DecryptionPacketsInputStream} does not return un-authenticated data. + * + * @param blobName The name of the blob to get an {@link InputStream} for. + */ + @Override + public InputStream readBlob(String blobName) throws IOException { + // This MIGHT require two concurrent readBlob connections if the DEK is not already in the cache and if the encrypted blob + // is large enough so that the underlying network library keeps the connection open after reading the prepended DEK ID. + // Arguably this is a problem only under lab conditions, when the storage service is saturated only by the first read + // connection of the pair, so that the second read connection (for the DEK) can not be fulfilled. + // In this case the second connection will time-out which will trigger the closing of the first one, therefore + // allowing other pair connections to complete. 
+ // In this situation the restore process should slowly make headway, albeit under read-timeout exceptions + final InputStream encryptedDataInputStream = delegatedBlobContainer.readBlob(blobName); + try { + // read the DEK Id (fixed length) which is prepended to the encrypted blob + final byte[] dekIdBytes = new byte[DEK_ID_LENGTH]; + final int bytesRead = Streams.readFully(encryptedDataInputStream, dekIdBytes); + if (bytesRead != DEK_ID_LENGTH) { + throw new RepositoryException(repositoryName, "The encrypted blob [" + blobName + "] is too small [" + bytesRead + "]"); + } + final String dekId = new String(dekIdBytes, StandardCharsets.UTF_8); + // might open a connection to read and decrypt the DEK, but most likely it will be served from cache + final SecretKey dek = getDEKById.apply(dekId); + // read and decrypt the rest of the blob + return new DecryptionPacketsInputStream(encryptedDataInputStream, dek, PACKET_LENGTH_IN_BYTES); + } catch (Exception e) { + try { + encryptedDataInputStream.close(); + } catch (IOException closeEx) { + e.addSuppressed(closeEx); + } + throw e; + } + } + + @Override + public InputStream readBlob(String blobName, long position, long length) throws IOException { + throw new UnsupportedOperationException("Not yet implemented"); + } + + /** + * Reads the blob content from the input stream and writes it to the container in a new blob with the given name. + * If {@code failIfAlreadyExists} is {@code true} and a blob with the same name already exists, the write operation will fail; + * otherwise, if {@code failIfAlreadyExists} is {@code false} the blob is overwritten. + * The contents are encrypted in a streaming fashion. The DEK (encryption key) is randomly generated and reused for encrypting + * subsequent blobs such that the same IV is not reused together with the same key. + * The DEK encryption key is separately stored in a different blob, which is encrypted with the repository key. 
+ * + * @param blobName + * The name of the blob to write the contents of the input stream to. + * @param inputStream + * The input stream from which to retrieve the bytes to write to the blob. + * @param blobSize + * The size of the blob to be written, in bytes. The actual number of bytes written to the storage service is larger + * because of encryption and authentication overhead. It is implementation dependent whether this value is used + * in writing the blob to the repository. + * @param failIfAlreadyExists + * whether to throw a FileAlreadyExistsException if the given blob already exists + */ + @Override + public void writeBlob(String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists) throws IOException { + // reuse, but possibly generate and store a new DEK + final SingleUseKey singleUseNonceAndDEK = singleUseDEKSupplier.get(); + final BytesReference dekIdBytes = singleUseNonceAndDEK.getKeyId(); + if (dekIdBytes.length() != DEK_ID_LENGTH) { + throw new RepositoryException( + repositoryName, + "Unexpected fatal internal error", + new IllegalStateException("Unexpected DEK Id length [" + dekIdBytes.length() + "]") + ); + } + final long encryptedBlobSize = getEncryptedBlobByteLength(blobSize); + try ( + InputStream encryptedInputStream = ChainingInputStream.chain( + dekIdBytes.streamInput(), + new EncryptionPacketsInputStream( + inputStream, + singleUseNonceAndDEK.getKey(), + singleUseNonceAndDEK.getNonce(), + PACKET_LENGTH_IN_BYTES + ) + ) + ) { + delegatedBlobContainer.writeBlob(blobName, encryptedInputStream, encryptedBlobSize, failIfAlreadyExists); + } + } + + @Override + public void writeBlobAtomic(String blobName, BytesReference bytes, boolean failIfAlreadyExists) throws IOException { + // the encrypted repository does not offer an alternative implementation for atomic writes + // fallback to regular write + writeBlob(blobName, bytes, failIfAlreadyExists); + } + + @Override + public DeleteResult delete() throws IOException { + 
return delegatedBlobContainer.delete(); + } + + @Override + public void deleteBlobsIgnoringIfNotExists(List blobNames) throws IOException { + delegatedBlobContainer.deleteBlobsIgnoringIfNotExists(blobNames); + } + + @Override + public Map listBlobs() throws IOException { + return delegatedBlobContainer.listBlobs(); + } + + @Override + public Map listBlobsByPrefix(String blobNamePrefix) throws IOException { + return delegatedBlobContainer.listBlobsByPrefix(blobNamePrefix); + } + + @Override + public Map children() throws IOException { + final Map childEncryptedBlobContainers = delegatedBlobContainer.children(); + final Map resultBuilder = new HashMap<>(childEncryptedBlobContainers.size()); + for (Map.Entry childBlobContainer : childEncryptedBlobContainers.entrySet()) { + if (childBlobContainer.getKey().equals(DEK_ROOT_CONTAINER) && false == path().iterator().hasNext()) { + // do not descend into the DEK blob container + continue; + } + // get an encrypted blob container for each child + // Note that the encryption metadata blob container might be missing + resultBuilder.put( + childBlobContainer.getKey(), + new EncryptedBlobContainer( + path().add(childBlobContainer.getKey()), + repositoryName, + childBlobContainer.getValue(), + singleUseDEKSupplier, + getDEKById + ) + ); + } + return resultBuilder; + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java new file mode 100644 index 0000000000000..03ff9a59d10da --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryPlugin.java @@ -0,0 +1,199 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Build; +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureSetting; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.env.Environment; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.RepositoryPlugin; +import org.elasticsearch.repositories.Repository; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.xpack.core.XPackPlugin; + +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +public class EncryptedRepositoryPlugin extends Plugin implements RepositoryPlugin { + + private static final Boolean ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED; + static { + final String property = System.getProperty("es.encrypted_repository_feature_flag_registered"); + if (Build.CURRENT.isSnapshot() && property != null) { + throw new IllegalArgumentException("es.encrypted_repository_feature_flag_registered is only supported in non-snapshot builds"); + } + if 
("true".equals(property)) { + ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED = true; + } else if ("false".equals(property)) { + ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED = false; + } else if (property == null) { + ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED = null; + } else { + throw new IllegalArgumentException( + "expected es.encrypted_repository_feature_flag_registered to be unset or [true|false] but was [" + property + "]" + ); + } + } + + static final Logger logger = LogManager.getLogger(EncryptedRepositoryPlugin.class); + static final String REPOSITORY_TYPE_NAME = "encrypted"; + // TODO add at least hdfs, and investigate supporting all `BlobStoreRepository` implementations + static final List SUPPORTED_ENCRYPTED_TYPE_NAMES = Arrays.asList("fs", "gcs", "azure", "s3"); + static final Setting.AffixSetting ENCRYPTION_PASSWORD_SETTING = Setting.affixKeySetting( + "repository.encrypted.", + "password", + key -> SecureSetting.secureString(key, null) + ); + static final Setting DELEGATE_TYPE_SETTING = Setting.simpleString("delegate_type", ""); + static final Setting PASSWORD_NAME_SETTING = Setting.simpleString("password_name", ""); + + // "protected" because it is overloaded for tests + protected XPackLicenseState getLicenseState() { + return XPackPlugin.getSharedLicenseState(); + } + + @Override + public List> getSettings() { + return Collections.singletonList(ENCRYPTION_PASSWORD_SETTING); + } + + @Override + public Map getRepositories( + Environment env, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings + ) { + // load all the passwords from the keystore in memory because the keystore is not readable when the repository is created + final Map repositoryPasswordsMapBuilder = new HashMap<>(); + for (String passwordName : ENCRYPTION_PASSWORD_SETTING.getNamespaces(env.settings())) { + Setting passwordSetting = ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(passwordName); + 
repositoryPasswordsMapBuilder.put(passwordName, passwordSetting.get(env.settings())); + logger.debug("Loaded repository password [{}] from the node keystore", passwordName); + } + final Map repositoryPasswordsMap = repositoryPasswordsMapBuilder; + + if (false == Build.CURRENT.isSnapshot() + && (ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED == null || ENCRYPTED_REPOSITORY_FEATURE_FLAG_REGISTERED == false)) { + return Collections.emptyMap(); + } + + return Collections.singletonMap(REPOSITORY_TYPE_NAME, new Repository.Factory() { + + @Override + public Repository create(RepositoryMetadata metadata) { + throw new UnsupportedOperationException(); + } + + @Override + public Repository create(RepositoryMetadata metadata, Function typeLookup) throws Exception { + final String delegateType = DELEGATE_TYPE_SETTING.get(metadata.settings()); + if (Strings.hasLength(delegateType) == false) { + throw new IllegalArgumentException("Repository setting [" + DELEGATE_TYPE_SETTING.getKey() + "] must be set"); + } + if (REPOSITORY_TYPE_NAME.equals(delegateType)) { + throw new IllegalArgumentException( + "Cannot encrypt an already encrypted repository. 
[" + + DELEGATE_TYPE_SETTING.getKey() + + "] must not be equal to [" + + REPOSITORY_TYPE_NAME + + "]" + ); + } + final Repository.Factory factory = typeLookup.apply(delegateType); + if (null == factory || false == SUPPORTED_ENCRYPTED_TYPE_NAMES.contains(delegateType)) { + throw new IllegalArgumentException( + "Unsupported delegate repository type [" + delegateType + "] for setting [" + DELEGATE_TYPE_SETTING.getKey() + "]" + ); + } + final String repositoryPasswordName = PASSWORD_NAME_SETTING.get(metadata.settings()); + if (Strings.hasLength(repositoryPasswordName) == false) { + throw new IllegalArgumentException("Repository setting [" + PASSWORD_NAME_SETTING.getKey() + "] must be set"); + } + final SecureString repositoryPassword = repositoryPasswordsMap.get(repositoryPasswordName); + if (repositoryPassword == null) { + throw new IllegalArgumentException( + "Secure setting [" + + ENCRYPTION_PASSWORD_SETTING.getConcreteSettingForNamespace(repositoryPasswordName).getKey() + + "] must be set" + ); + } + final Repository delegatedRepository = factory.create( + new RepositoryMetadata(metadata.name(), delegateType, metadata.settings()) + ); + if (false == (delegatedRepository instanceof BlobStoreRepository) || delegatedRepository instanceof EncryptedRepository) { + throw new IllegalArgumentException("Unsupported delegate repository type [" + DELEGATE_TYPE_SETTING.getKey() + "]"); + } + if (false == getLicenseState().checkFeature(XPackLicenseState.Feature.ENCRYPTED_SNAPSHOT)) { + logger.warn( + new ParameterizedMessage( + "Encrypted snapshots are not allowed for the currently installed license [{}]." + + " Snapshots to the [{}] encrypted repository are not permitted." 
+ + " All the other operations, including restore, work without restrictions.", + getLicenseState().getOperationMode().description(), + metadata.name() + ), + LicenseUtils.newComplianceException("encrypted snapshots") + ); + } + return createEncryptedRepository( + metadata, + registry, + clusterService, + bigArrays, + recoverySettings, + (BlobStoreRepository) delegatedRepository, + () -> getLicenseState(), + repositoryPassword + ); + } + }); + } + + // protected for tests + protected EncryptedRepository createEncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repoPassword + ) throws GeneralSecurityException { + return new EncryptedRepository( + metadata, + registry, + clusterService, + bigArrays, + recoverySettings, + delegatedRepository, + licenseStateSupplier, + repoPassword + ); + } +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java new file mode 100644 index 0000000000000..a08ea4216c94d --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStream.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.core.internal.io.IOUtils; + +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Objects; + +/** + * An {@code EncryptionPacketsInputStream} wraps another input stream and encrypts its contents. + * The method of encryption is AES/GCM/NoPadding, which is a type of authenticated encryption. + * The encryption works packet wise, i.e. the stream is segmented into fixed-size byte packets + * which are separately encrypted using a unique {@link Cipher}. As an exception, only the last + * packet will have a different size, possibly zero. Note that the encrypted packets are + * larger compared to the plaintext packets, because they contain a 16 byte length trailing + * authentication tag. The resulting encrypted and authenticated packets are assembled back into + * the resulting stream. + *

+ * The packets are encrypted using the same {@link SecretKey} but using a different initialization + * vector. The IV of each packet is 12 bytes wide and is comprised of a 4-byte integer {@code nonce}, + * the same for every packet in the stream, and a monotonically increasing 8-byte integer counter. + * The caller must assure that the same {@code nonce} is not reused for other encrypted streams + * using the same {@code secretKey}. The counter from the IV identifies the position of the packet + * in the encrypted stream, so that packets cannot be reordered without breaking the decryption. + * When assembling the encrypted stream, the IV is prepended to the corresponding packet's ciphertext. + *

+ * The packet length is preferably a large multiple (typically 128) of the AES block size (128 bytes), + * but any positive integer value smaller than {@link EncryptedRepository#MAX_PACKET_LENGTH_IN_BYTES} + * is valid. A larger packet length incurs smaller relative size overhead because the 12 byte wide IV + * and the 16 byte wide authentication tag are constant no matter the packet length. A larger packet + * length also exposes more opportunities for the JIT compilation of the AES encryption loop. But + * {@code mark} will buffer up to packet length bytes, and, more importantly, decryption might + * need to allocate a memory buffer the size of the packet in order to assure that no un-authenticated + * decrypted ciphertext is returned. The decryption procedure is the primary factor that limits the + * packet length. + *

+ * This input stream supports the {@code mark} and {@code reset} operations, but only if the wrapped + * stream supports them as well. A {@code mark} call will trigger the memory buffering of the current + * packet and will also trigger a {@code mark} call on the wrapped input stream on the next + * packet boundary. Upon a {@code reset} call, the buffered packet will be replayed and new packets + * will be generated starting from the marked packet boundary on the wrapped stream. + *

+ * The {@code close} call will close the encryption input stream and any subsequent {@code read}, + * {@code skip}, {@code available} and {@code reset} calls will throw {@code IOException}s. + *

+ * This is NOT thread-safe, multiple threads sharing a single instance must synchronize access. + * + * @see DecryptionPacketsInputStream + */ +public final class EncryptionPacketsInputStream extends ChainingInputStream { + + private final SecretKey secretKey; + private final int packetLength; + private final ByteBuffer packetIv; + private final int encryptedPacketLength; + + final InputStream source; // package-protected for tests + long counter; // package-protected for tests + Long markCounter; // package-protected for tests + int markSourceOnNextPacket; // package-protected for tests + + /** + * Computes and returns the length of the ciphertext given the {@code plaintextLength} and the {@code packetLength} + * used during encryption. + * The plaintext is segmented into packets of equal {@code packetLength} length, with the exception of the last + * packet which is shorter and can have a length of {@code 0}. Encryption is packet-wise and is 1:1, with no padding. + * But each encrypted packet is prepended by the Initilization Vector and appended the Authentication Tag, including + * the last packet, so when pieced together will amount to a longer resulting ciphertext. 
+ * + * @see DecryptionPacketsInputStream#getDecryptionLength(long, int) + */ + public static long getEncryptionLength(long plaintextLength, int packetLength) { + return plaintextLength + (plaintextLength / packetLength + 1) * (EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES + + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES); + } + + public EncryptionPacketsInputStream(InputStream source, SecretKey secretKey, int nonce, int packetLength) { + this.source = Objects.requireNonNull(source); + this.secretKey = Objects.requireNonNull(secretKey); + if (packetLength <= 0 || packetLength >= EncryptedRepository.MAX_PACKET_LENGTH_IN_BYTES) { + throw new IllegalArgumentException("Invalid packet length [" + packetLength + "]"); + } + this.packetLength = packetLength; + this.packetIv = ByteBuffer.allocate(EncryptedRepository.GCM_IV_LENGTH_IN_BYTES).order(ByteOrder.LITTLE_ENDIAN); + // nonce takes the first 4 bytes of the IV + this.packetIv.putInt(0, nonce); + this.encryptedPacketLength = packetLength + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + this.counter = EncryptedRepository.PACKET_START_COUNTER; + this.markCounter = null; + this.markSourceOnNextPacket = -1; + } + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + // the last packet input stream is the only one shorter than encryptedPacketLength + if (currentComponentIn != null && ((CountingInputStream) currentComponentIn).getCount() < encryptedPacketLength) { + // there are no more packets + return null; + } + // If the enclosing stream has a mark set, + // then apply it to the source input stream when we reach a packet boundary + if (markSourceOnNextPacket != -1) { + source.mark(markSourceOnNextPacket); + markSourceOnNextPacket = -1; + } + // create the new packet + InputStream encryptionInputStream = new PrefixInputStream(source, packetLength, false); + // the counter takes up the last 8 bytes of the packet IV (12 byte wide) + // the 
first 4 bytes are used by the nonce (which is the same for every packet IV) + packetIv.putLong(Integer.BYTES, counter++); + // counter wrap around + if (counter == EncryptedRepository.PACKET_START_COUNTER) { + throw new IOException("Maximum packet count limit exceeded"); + } + Cipher packetCipher = getPacketEncryptionCipher(secretKey, packetIv.array()); + encryptionInputStream = new CipherInputStream(encryptionInputStream, packetCipher); + encryptionInputStream = new SequenceInputStream(new ByteArrayInputStream(packetIv.array()), encryptionInputStream); + encryptionInputStream = new BufferOnMarkInputStream(encryptionInputStream, encryptedPacketLength); + return new CountingInputStream(encryptionInputStream, false); + } + + // remove after https://github.com/elastic/elasticsearch/pull/66769 is merged in + @Override + public int available() throws IOException { + return 0; + } + + @Override + public boolean markSupported() { + return source.markSupported(); + } + + @Override + public void mark(int readlimit) { + if (markSupported()) { + if (readlimit <= 0) { + throw new IllegalArgumentException("Mark readlimit must be a positive integer"); + } + // handles the packet-wise part of the marking operation + super.mark(encryptedPacketLength); + // saves the counter used to generate packet IVs + markCounter = counter; + // stores the flag used to mark the source input stream at packet boundary + markSourceOnNextPacket = readlimit; + } + } + + @Override + public void reset() throws IOException { + if (false == markSupported()) { + throw new IOException("Mark/reset not supported"); + } + if (markCounter == null) { + throw new IOException("Mark no set"); + } + super.reset(); + counter = markCounter; + if (markSourceOnNextPacket == -1) { + source.reset(); + } + } + + @Override + public void close() throws IOException { + IOUtils.close(super::close, source); + } + + private static Cipher getPacketEncryptionCipher(SecretKey secretKey, byte[] packetIv) throws IOException { + 
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES * Byte.SIZE, packetIv); + try { + Cipher packetCipher = Cipher.getInstance(EncryptedRepository.DATA_ENCRYPTION_SCHEME); + packetCipher.init(Cipher.ENCRYPT_MODE, secretKey, gcmParameterSpec); + return packetCipher; + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | InvalidAlgorithmParameterException e) { + throw new IOException(e); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java new file mode 100644 index 0000000000000..8583cb6ab37ab --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/PrefixInputStream.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Locale; + +/** + * A {@code PrefixInputStream} wraps another input stream and exposes + * only the first bytes of it. Reading from the wrapping + * {@code PrefixInputStream} consumes the underlying stream. The stream + * is exhausted when {@code prefixLength} bytes have been read, or the underlying + * stream is exhausted before that. + *

+ * Only if the {@code closeSource} constructor argument is {@code true}, the + * closing of this stream will also close the underlying input stream. + * Any subsequent {@code read}, {@code skip} and {@code available} calls + * will throw {@code IOException}s. + */ +public final class PrefixInputStream extends InputStream { + + /** + * The underlying stream of which only a prefix is returned + */ + private final InputStream source; + /** + * The length in bytes of the prefix. + * This is the maximum number of bytes that can be read from this stream, + * but fewer bytes can be read if the wrapped source stream itself contains fewer bytes + */ + private final int prefixLength; + /** + * The current count of bytes read from this stream. + * This starts of as {@code 0} and is always smaller or equal to {@code prefixLength}. + */ + private int count; + /** + * whether closing this stream must also close the underlying stream + */ + private boolean closeSource; + /** + * flag signalling if this stream has been closed + */ + private boolean closed; + + public PrefixInputStream(InputStream source, int prefixLength, boolean closeSource) { + if (prefixLength < 0) { + throw new IllegalArgumentException("The prefixLength constructor argument must be a positive integer"); + } + this.source = source; + this.prefixLength = prefixLength; + this.count = 0; + this.closeSource = closeSource; + this.closed = false; + } + + @Override + public int read() throws IOException { + ensureOpen(); + if (remainingPrefixByteCount() <= 0) { + return -1; + } + int byteVal = source.read(); + if (byteVal == -1) { + return -1; + } + count++; + return byteVal; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + ensureOpen(); + // this is `Objects#checkFromIndexSize(off, len, b.length)` from Java 9 + if ((b.length | off | len) < 0 || len > b.length - off) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, % 0; + long bytesSkipped = 
source.skip(bytesToSkip); + count += bytesSkipped; + return bytesSkipped; + } + + @Override + public int available() throws IOException { + ensureOpen(); + return Math.min(remainingPrefixByteCount(), source.available()); + } + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark(int readlimit) { + // mark and reset are not supported + } + + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } + + @Override + public void close() throws IOException { + if (closed) { + return; + } + closed = true; + if (closeSource) { + source.close(); + } + } + + private int remainingPrefixByteCount() { + return prefixLength - count; + } + + private void ensureOpen() throws IOException { + if (closed) { + throw new IOException("Stream has been closed"); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java new file mode 100644 index 0000000000000..fe4729e8acec1 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/java/org/elasticsearch/repositories/encrypted/SingleUseKey.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; + +import javax.crypto.SecretKey; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Container class for a {@code SecretKey} with a unique identifier, and a 4-byte wide {@code Integer} nonce, that can be used for a + * single encryption operation. Use {@link #createSingleUseKeySupplier(CheckedSupplier)} to obtain a {@code Supplier} that returns + * a new {@link SingleUseKey} instance on every invocation. The number of unique {@code SecretKey}s (and their associated identifiers) + * generated is minimized and, at the same time, ensuring that a given {@code nonce} is not reused with the same key. + */ +final class SingleUseKey { + private static final Logger logger = LogManager.getLogger(SingleUseKey.class); + static final int MIN_NONCE = Integer.MIN_VALUE; + static final int MAX_NONCE = Integer.MAX_VALUE; + private static final int MAX_ATTEMPTS = 9; + private static final SingleUseKey EXPIRED_KEY = new SingleUseKey(null, null, MAX_NONCE); + + private final BytesReference keyId; + private final SecretKey key; + private final int nonce; + + // for tests use only! + SingleUseKey(BytesReference KeyId, SecretKey Key, int nonce) { + this.keyId = KeyId; + this.key = Key; + this.nonce = nonce; + } + + public BytesReference getKeyId() { + return keyId; + } + + public SecretKey getKey() { + return key; + } + + public int getNonce() { + return nonce; + } + + /** + * Returns a {@code CheckedSupplier} of {@code SingleUseKey}s so that no two instances contain the same key and nonce pair. 
+ * The current implementation increments the {@code nonce} while keeping the key constant, until the {@code nonce} space + * is exhausted, at which moment a new key is generated and the {@code nonce} is reset back. + * + * @param keyGenerator supplier for the key and the key id + */ + static CheckedSupplier createSingleUseKeySupplier( + CheckedSupplier, T> keyGenerator + ) { + final AtomicReference keyCurrentlyInUse = new AtomicReference<>(EXPIRED_KEY); + return internalSingleUseKeySupplier(keyGenerator, keyCurrentlyInUse); + } + + // for tests use only, the {@code keyCurrentlyInUse} must not be exposed to caller code + static CheckedSupplier internalSingleUseKeySupplier( + CheckedSupplier, T> keyGenerator, + AtomicReference keyCurrentlyInUse + ) { + final Object lock = new Object(); + return () -> { + for (int attemptNo = 0; attemptNo < MAX_ATTEMPTS; attemptNo++) { + final SingleUseKey nonceAndKey = keyCurrentlyInUse.getAndUpdate( + prev -> prev.nonce < MAX_NONCE ? new SingleUseKey(prev.keyId, prev.key, prev.nonce + 1) : EXPIRED_KEY + ); + if (nonceAndKey.nonce < MAX_NONCE) { + // this is the commonly used code path, where just the nonce is incremented + logger.trace( + () -> new ParameterizedMessage("Key with id [{}] reused with nonce [{}]", nonceAndKey.keyId, nonceAndKey.nonce) + ); + return nonceAndKey; + } else { + // this is the infrequent code path, where a new key is generated and the nonce is reset back + logger.trace( + () -> new ParameterizedMessage("Try to generate a new key to replace the key with id [{}]", nonceAndKey.keyId) + ); + synchronized (lock) { + if (keyCurrentlyInUse.get().nonce == MAX_NONCE) { + final Tuple newKey = keyGenerator.get(); + logger.debug(() -> new ParameterizedMessage("New key with id [{}] has been generated", newKey.v1())); + keyCurrentlyInUse.set(new SingleUseKey(newKey.v1(), newKey.v2(), MIN_NONCE)); + } + } + } + } + throw new IllegalStateException("Failure to generate new key"); + }; + } +} diff --git 
a/x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy new file mode 100644 index 0000000000000..7f75e2af67c6e --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/main/plugin-metadata/plugin-security.policy @@ -0,0 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +grant { +}; diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java new file mode 100644 index 0000000000000..e1b8beda7247a --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/AESKeyUtilsTests.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.test.ESTestCase; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.security.InvalidKeyException; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class AESKeyUtilsTests extends ESTestCase { + + public void testWrapUnwrap() throws Exception { + byte[] keyToWrapBytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey keyToWrap = new SecretKeySpec(keyToWrapBytes, "AES"); + byte[] wrappingKeyBytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey wrappingKey = new SecretKeySpec(wrappingKeyBytes, "AES"); + byte[] wrappedKey = AESKeyUtils.wrap(wrappingKey, keyToWrap); + assertThat(wrappedKey.length, equalTo(AESKeyUtils.WRAPPED_KEY_LENGTH_IN_BYTES)); + SecretKey unwrappedKey = AESKeyUtils.unwrap(wrappingKey, wrappedKey); + assertThat(unwrappedKey, equalTo(keyToWrap)); + } + + public void testComputeId() throws Exception { + byte[] key1Bytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey key1 = new SecretKeySpec(key1Bytes, "AES"); + byte[] key2Bytes = randomByteArrayOfLength(AESKeyUtils.KEY_LENGTH_IN_BYTES); + SecretKey key2 = new SecretKeySpec(key2Bytes, "AES"); + assertThat(AESKeyUtils.computeId(key1), not(equalTo(AESKeyUtils.computeId(key2)))); + assertThat(AESKeyUtils.computeId(key1), equalTo(AESKeyUtils.computeId(key1))); + assertThat(AESKeyUtils.computeId(key2), equalTo(AESKeyUtils.computeId(key2))); + } + + public void testFailedWrapUnwrap() throws Exception { + byte[] toWrapBytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 }; + SecretKey keyToWrap = new SecretKeySpec(toWrapBytes, "AES"); + byte[] wrapBytes = new byte[] { 0, 0, 0, 0, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 0, 0, 0 }; + SecretKey wrappingKey = new 
SecretKeySpec(wrapBytes, "AES"); + byte[] wrappedKey = AESKeyUtils.wrap(wrappingKey, keyToWrap); + for (int i = 0; i < wrappedKey.length; i++) { + wrappedKey[i] ^= 0xFFFFFFFF; + expectThrows(InvalidKeyException.class, () -> AESKeyUtils.unwrap(wrappingKey, wrappedKey)); + wrappedKey[i] ^= 0xFFFFFFFF; + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java new file mode 100644 index 0000000000000..fd221e7d86ce4 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/BufferOnMarkInputStreamTests.java @@ -0,0 +1,854 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.BeforeClass; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readNBytes; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class BufferOnMarkInputStreamTests extends ESTestCase { + + private static byte[] testArray; + + @BeforeClass + static void createTestArray() throws Exception { + testArray = new byte[128]; + for (int i = 0; i < testArray.length; i++) { + testArray[i] = (byte) i; + } + } + + public void testResetWithoutMarkFails() throws Exception { + Tuple mockSourceTuple = getMockInfiniteInputStream(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), 1 + Randomness.get().nextInt(1024)); + // maybe read some bytes + readNBytes(test, randomFrom(0, randomInt(31))); + IOException e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Mark not called or has been invalidated")); + } + + public void testMarkAndBufferReadLimitsCheck() throws Exception { + Tuple mockSourceTuple = getMockInfiniteInputStream(); + int bufferSize = randomIntBetween(1, 1024); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + assertThat(test.getMaxMarkReadlimit(), Matchers.is(bufferSize)); + // maybe read some bytes + readNBytes(test, randomFrom(0, randomInt(32))); + int wrongLargeReadLimit = bufferSize + randomIntBetween(1, 8); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> { 
test.mark(wrongLargeReadLimit); }); + assertThat( + e.getMessage(), + Matchers.is("Readlimit value [" + wrongLargeReadLimit + "] exceeds the maximum value of [" + bufferSize + "]") + ); + e = expectThrows(IllegalArgumentException.class, () -> { test.mark(-1 - randomInt(1)); }); + assertThat(e.getMessage(), Matchers.containsString("cannot be negative")); + e = expectThrows(IllegalArgumentException.class, () -> { new BufferOnMarkInputStream(mock(InputStream.class), 0 - randomInt(1)); }); + assertThat(e.getMessage(), Matchers.is("The buffersize constructor argument must be a strictly positive value")); + } + + public void testCloseRejectsSuccessiveCalls() throws Exception { + int bufferSize = 3 + Randomness.get().nextInt(128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + readNBytes(test, randomFrom(0, Randomness.get().nextInt(32))); + test.close(); + int bytesReadBefore = bytesRead.get(); + IOException e = expectThrows(IOException.class, () -> { test.read(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { + byte[] b = new byte[1 + Randomness.get().nextInt(32)]; + test.read(b, 0, 1 + Randomness.get().nextInt(b.length)); + }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.skip(1 + Randomness.get().nextInt(32)); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.available(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + int bytesReadAfter = bytesRead.get(); + assertThat(bytesReadAfter - bytesReadBefore, 
Matchers.is(0)); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + } + + public void testBufferingUponMark() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + // read without mark, should be a simple pass-through with the same byte count + int bytesReadBefore = bytesRead.get(); + assertThat(test.read(), Matchers.not(-1)); + int bytesReadAfter = bytesRead.get(); + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(1)); + int readLen = randomIntBetween(1, 8); + bytesReadBefore = bytesRead.get(); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + bytesReadAfter = bytesRead.get(); + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert no buffering + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read one byte + bytesReadBefore = bytesRead.get(); + assertThat(test.read(), Matchers.not(-1)); + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(1)); + // assert byte is buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - 1)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(1)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, 
Matchers.is(false)); + // read more bytes, up to buffer size bytes + readLen = randomIntBetween(1, bufferSize - 1); + bytesReadBefore = bytesRead.get(); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert byte is buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - 1 - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(1 + readLen)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + assertThat(test.storeToBuffer, Matchers.is(true)); + } + + public void testMarkInvalidation() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(false)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read all bytes to fill the mark buffer + int bytesReadBefore = bytesRead.get(); + // read enough to populate the full buffer space + int readLen = bufferSize; + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert byte is buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), 
Matchers.is(bufferSize)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + assertThat(test.storeToBuffer, Matchers.is(true)); + // read another one byte + bytesReadBefore = bytesRead.get(); + assertThat(test.read(), Matchers.not(-1)); + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(1)); + // assert mark is invalidated and no buffering is further performed + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + assertThat(test.storeToBuffer, Matchers.is(false)); + // read more bytes + bytesReadBefore = bytesRead.get(); + readLen = randomIntBetween(1, 2 * bufferSize); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert byte again is NOT buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(false)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // assert reset does not work any more + IOException e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Mark not called or has been invalidated")); + } + + public void testConsumeBufferUponReset() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + readNBytes(test, 
randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + assertThat(test.replayFromBuffer, Matchers.is(true)); + assertThat(test.storeToBuffer, Matchers.is(true)); + // read again, from buffer this time + bytesReadBefore = bytesRead.get(); + int readLen2 = randomIntBetween(1, readLen); + if (randomBoolean()) { + readNBytes(test, readLen2); + } else { + skipNBytes(test, readLen2); + } + bytesReadAfter = bytesRead.get(); + // assert bytes are replayed from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(true)); + } + + public void testInvalidateMarkAfterReset() throws Exception { + int bufferSize = randomIntBetween(3, 128); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + 
readNBytes(test, randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + // assert signal for replay from buffer is toggled + assertThat(test.replayFromBuffer, Matchers.is(true)); + assertThat(test.storeToBuffer, Matchers.is(true)); + // assert bytes are still buffered + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // read again, from buffer this time + bytesReadBefore = bytesRead.get(); + // read all bytes from the buffer + int readLen2 = readLen; + if (randomBoolean()) { + readNBytes(test, readLen2); + } else { + skipNBytes(test, readLen2); + } + bytesReadAfter = bytesRead.get(); + // assert bytes are replayed from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(true)); + // read on, from the stream, until the mark buffer is full + 
bytesReadBefore = bytesRead.get(); + // read the remaining bytes to fill the buffer + int readLen3 = bufferSize - readLen; + if (randomBoolean()) { + readNBytes(test, readLen3); + } else { + skipNBytes(test, readLen3); + } + bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen3)); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen3)); + assertThat(test.storeToBuffer, Matchers.is(true)); + if (readLen3 > 0) { + assertThat(test.replayFromBuffer, Matchers.is(false)); + } else { + assertThat(test.replayFromBuffer, Matchers.is(true)); + } + // read more bytes + bytesReadBefore = bytesRead.get(); + int readLen4 = randomIntBetween(1, 2 * bufferSize); + if (randomBoolean()) { + readNBytes(test, readLen4); + } else { + skipNBytes(test, readLen4); + } + bytesReadAfter = bytesRead.get(); + // assert byte is "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen4)); + // assert mark reset + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test.storeToBuffer, Matchers.is(false)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // assert reset does not work anymore + IOException e = expectThrows(IOException.class, () -> { test.reset(); }); + assertThat(e.getMessage(), Matchers.is("Mark not called or has been invalidated")); + } + + public void testMarkAfterResetWhileReplayingBuffer() throws Exception { + int bufferSize = randomIntBetween(8, 16); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + 
readNBytes(test, randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + assertThat(test.replayFromBuffer, Matchers.is(true)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // read bytes after reset + for (int readLen2 = 1; readLen2 <= readLen; readLen2++) { + Tuple mockSourceTuple2 = getMockInfiniteInputStream(); + BufferOnMarkInputStream cloneTest = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + cloneBufferOnMarkStream(cloneTest, test); + AtomicInteger bytesRead2 = mockSourceTuple2.v1(); + // read again, from buffer this time, less than before + bytesReadBefore = bytesRead2.get(); + if (randomBoolean()) { + readNBytes(cloneTest, readLen2); + } else { + skipNBytes(cloneTest, readLen2); + } + bytesReadAfter = bytesRead2.get(); + // assert bytes are replayed from the buffer, and not read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed + assertThat(cloneTest.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + 
assertThat(cloneTest.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2)); + assertThat(cloneTest.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest.replayFromBuffer, Matchers.is(true)); + // mark inside the buffer after reset + cloneTest.mark(randomIntBetween(1, bufferSize)); + assertThat(cloneTest.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen + readLen2)); + assertThat(cloneTest.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2)); + assertThat(cloneTest.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest.replayFromBuffer, Matchers.is(true)); + // read until the buffer is filled + for (int readLen3 = 1; readLen3 <= readLen - readLen2; readLen3++) { + Tuple mockSourceTuple3 = getMockInfiniteInputStream(); + BufferOnMarkInputStream cloneTest3 = new BufferOnMarkInputStream(mockSourceTuple3.v2(), bufferSize); + cloneBufferOnMarkStream(cloneTest3, cloneTest); + AtomicInteger bytesRead3 = mockSourceTuple3.v1(); + // read again from buffer, after the mark inside the buffer + bytesReadBefore = bytesRead3.get(); + if (randomBoolean()) { + readNBytes(cloneTest3, readLen3); + } else { + skipNBytes(cloneTest3, readLen3); + } + bytesReadAfter = bytesRead3.get(); + // assert bytes are replayed from the buffer, and not read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(0)); + // assert buffer is consumed completely + assertThat(cloneTest3.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen + readLen2)); + assertThat(cloneTest3.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - readLen2 - readLen3)); + assertThat(cloneTest3.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest3.replayFromBuffer, Matchers.is(true)); + } + // read beyond the buffer can supply, but not more than it can accommodate + for (int readLen3 = readLen - readLen2 + 1; readLen3 <= bufferSize - readLen2; readLen3++) { + Tuple mockSourceTuple3 = 
getMockInfiniteInputStream(); + BufferOnMarkInputStream cloneTest3 = new BufferOnMarkInputStream(mockSourceTuple3.v2(), bufferSize); + cloneBufferOnMarkStream(cloneTest3, cloneTest); + AtomicInteger bytesRead3 = mockSourceTuple3.v1(); + // read again from buffer, after the mark inside the buffer + bytesReadBefore = bytesRead3.get(); + if (randomBoolean()) { + readNBytes(cloneTest3, readLen3); + } else { + skipNBytes(cloneTest3, readLen3); + } + bytesReadAfter = bytesRead3.get(); + // assert bytes are PARTLY replayed, PARTLY read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen3 + readLen2 - readLen)); + // assert buffer is appended and fully replayed + assertThat(cloneTest3.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen3)); + assertThat(cloneTest3.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen3 + readLen2 - readLen)); + assertThat(cloneTest3.storeToBuffer, Matchers.is(true)); + assertThat(cloneTest3.replayFromBuffer, Matchers.is(false)); + } + } + } + + public void testMarkAfterResetAfterReplayingBuffer() throws Exception { + int bufferSize = randomIntBetween(8, 16); + Tuple mockSourceTuple = getMockInfiniteInputStream(); + AtomicInteger bytesRead = mockSourceTuple.v1(); + BufferOnMarkInputStream test = new BufferOnMarkInputStream(mockSourceTuple.v2(), bufferSize); + // maybe read some bytes + readNBytes(test, randomFrom(0, randomInt(32))); + // mark + test.mark(randomIntBetween(1, bufferSize)); + // read less than bufferSize bytes + int bytesReadBefore = bytesRead.get(); + int readLen = randomIntBetween(1, bufferSize); + if (randomBoolean()) { + readNBytes(test, readLen); + } else { + skipNBytes(test, readLen); + } + int bytesReadAfter = bytesRead.get(); + // assert bytes are "read" and not returned from the buffer + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen)); + // assert buffer is populated + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), 
Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(false)); + // reset + test.reset(); + assertThat(test.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen)); + assertThat(test.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertThat(test.storeToBuffer, Matchers.is(true)); + assertThat(test.replayFromBuffer, Matchers.is(true)); + for (int readLen2 = readLen + 1; readLen2 <= bufferSize; readLen2++) { + Tuple mockSourceTuple2 = getMockInfiniteInputStream(); + BufferOnMarkInputStream test2 = new BufferOnMarkInputStream(mockSourceTuple2.v2(), bufferSize); + cloneBufferOnMarkStream(test2, test); + AtomicInteger bytesRead2 = mockSourceTuple2.v1(); + // read again, more than before + bytesReadBefore = bytesRead2.get(); + if (randomBoolean()) { + readNBytes(test2, readLen2); + } else { + skipNBytes(test2, readLen2); + } + bytesReadAfter = bytesRead2.get(); + // assert bytes are PARTLY replayed, PARTLY read from the stream + assertThat(bytesReadAfter - bytesReadBefore, Matchers.is(readLen2 - readLen)); + // assert buffer is appended and fully replayed + assertThat(test2.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize - readLen2)); + assertThat(test2.storeToBuffer, Matchers.is(true)); + assertThat(test2.replayFromBuffer, Matchers.is(false)); + // mark + test2.mark(randomIntBetween(1, bufferSize)); + assertThat(test2.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(bufferSize)); + assertThat(test2.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + assertThat(test2.storeToBuffer, Matchers.is(true)); + assertThat(test2.replayFromBuffer, Matchers.is(false)); + } + } + + public void testNoMockSimpleMarkResetAtBeginning() throws Exception { + for (int length = 1; length <= 8; length++) { + for (int mark = 1; mark <= length; mark++) { + 
try (BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), mark)) { + in.mark(mark); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(mark)); + byte[] test1 = readNBytes(in, mark); + assertArray(0, test1); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + in.reset(); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + byte[] test2 = readNBytes(in, mark); + assertArray(0, test2); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + } + } + } + } + + public void testNoMockMarkResetAtBeginning() throws Exception { + for (int length = 1; length <= 8; length++) { + try (BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length)) { + in.mark(length); + // increasing length read/reset + for (int readLen = 1; readLen <= length; readLen++) { + byte[] test1 = readNBytes(in, readLen); + assertArray(0, test1); + in.reset(); + } + } + try (BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length)) { + in.mark(length); + // decreasing length read/reset + for (int readLen = length; readLen >= 1; readLen--) { + byte[] test1 = readNBytes(in, readLen); + assertArray(0, test1); + in.reset(); + } + } + } + } + + public void testNoMockSimpleMarkResetEverywhere() throws Exception { + for (int length = 1; length <= 10; length++) { + for (int offset = 0; offset < length; offset++) { + for (int mark = 1; mark <= length - offset; mark++) { + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), mark) + ) { + // skip first offset bytes + readNBytes(in, offset); + in.mark(mark); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(mark)); + byte[] test1 = readNBytes(in, mark); + assertArray(offset, test1); + 
assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + in.reset(); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + byte[] test2 = readNBytes(in, mark); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(0)); + assertArray(offset, test2); + } + } + } + } + } + + public void testNoMockMarkResetEverywhere() throws Exception { + for (int length = 1; length <= 8; length++) { + for (int offset = 0; offset < length; offset++) { + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length) + ) { + // skip first offset bytes + readNBytes(in, offset); + in.mark(length); + // increasing read lengths + for (int readLen = 1; readLen <= length - offset; readLen++) { + byte[] test = readNBytes(in, readLen); + assertArray(offset, test); + in.reset(); + } + } + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), length) + ) { + // skip first offset bytes + readNBytes(in, offset); + in.mark(length); + // decreasing read lengths + for (int readLen = length - offset; readLen >= 1; readLen--) { + byte[] test = readNBytes(in, readLen); + assertArray(offset, test); + in.reset(); + } + } + } + } + } + + public void testNoMockDoubleMarkEverywhere() throws Exception { + for (int length = 1; length <= 16; length++) { + for (int offset = 0; offset < length; offset++) { + for (int readLen = 1; readLen <= length - offset; readLen++) { + for (int markLen = 1; markLen <= length - offset; markLen++) { + try ( + BufferOnMarkInputStream in = new BufferOnMarkInputStream( + new NoMarkByteArrayInputStream(testArray, 0, length), + length + ) + ) { + readNBytes(in, offset); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + // first mark + in.mark(length - offset); + 
assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + byte[] test = readNBytes(in, readLen); + assertArray(offset, test); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // reset to first + in.reset(); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + // advance before/after the first read length + test = readNBytes(in, markLen); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - Math.max(readLen, markLen))); + if (markLen <= readLen) { + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen)); + } else { + assertThat(in.replayFromBuffer, Matchers.is(false)); + } + assertArray(offset, test); + // second mark + in.mark(length - offset - markLen); + if (markLen <= readLen) { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen + markLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen)); + } else { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + } + for (int readLen2 = 1; readLen2 <= length - offset - markLen; readLen2++) { + byte[] test2 = readNBytes(in, readLen2); + if (markLen + readLen2 <= readLen) { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen + markLen)); + assertThat(in.replayFromBuffer, Matchers.is(true)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen - readLen2)); + } else { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen2)); + 
assertThat(in.replayFromBuffer, Matchers.is(false)); + } + assertArray(offset + markLen, test2); + in.reset(); + assertThat(in.replayFromBuffer, Matchers.is(true)); + if (markLen + readLen2 <= readLen) { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen + markLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen - markLen)); + } else { + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(length - readLen2)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen2)); + } + } + } + } + } + } + } + } + + public void testNoMockMarkWithoutReset() throws Exception { + int maxMark = 8; + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, testArray.length), maxMark); + int offset = 0; + while (offset < testArray.length) { + int readLen = Math.min(1 + Randomness.get().nextInt(maxMark), testArray.length - offset); + in.mark(Randomness.get().nextInt(readLen)); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(maxMark)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(0)); + byte[] test = readNBytes(in, readLen); + assertThat(in.ringBuffer.getAvailableToWriteByteCount(), Matchers.is(maxMark - readLen)); + assertThat(in.ringBuffer.getAvailableToReadByteCount(), Matchers.is(readLen)); + assertArray(offset, test); + offset += readLen; + } + } + + public void testNoMockThreeMarkResetMarkSteps() throws Exception { + int length = randomIntBetween(8, 16); + int stepLen = randomIntBetween(4, 8); + BufferOnMarkInputStream in = new BufferOnMarkInputStream(new NoMarkByteArrayInputStream(testArray, 0, length), stepLen); + testMarkResetMarkStep(in, 0, length, stepLen, 2); + } + + private void testMarkResetMarkStep(BufferOnMarkInputStream stream, int offset, int length, int stepLen, int step) throws Exception { + stream.mark(stepLen); + for (int readLen = 1; readLen <= Math.min(stepLen, 
length - offset); readLen++) { + for (int markLen = 1; markLen <= Math.min(stepLen, length - offset); markLen++) { + BufferOnMarkInputStream cloneStream = cloneBufferOnMarkStream(stream); + // read ahead + byte[] test = readNBytes(cloneStream, readLen); + assertArray(offset, test); + // reset back + cloneStream.reset(); + // read ahead different length + test = readNBytes(cloneStream, markLen); + assertArray(offset, test); + if (step > 0) { + testMarkResetMarkStep(cloneStream, offset + markLen, length, stepLen, step - 1); + } + } + } + } + + private BufferOnMarkInputStream cloneBufferOnMarkStream(BufferOnMarkInputStream orig) { + int origOffset = ((NoMarkByteArrayInputStream) orig.source).getPos(); + int origLen = ((NoMarkByteArrayInputStream) orig.source).getCount(); + BufferOnMarkInputStream cloneStream = new BufferOnMarkInputStream( + new NoMarkByteArrayInputStream(testArray, origOffset, origLen - origOffset), + orig.ringBuffer.getBufferSize() + ); + if (orig.ringBuffer.buffer != null) { + cloneStream.ringBuffer.buffer = Arrays.copyOf(orig.ringBuffer.buffer, orig.ringBuffer.buffer.length); + } else { + cloneStream.ringBuffer.buffer = null; + } + cloneStream.ringBuffer.head = orig.ringBuffer.head; + cloneStream.ringBuffer.tail = orig.ringBuffer.tail; + cloneStream.ringBuffer.position = orig.ringBuffer.position; + cloneStream.storeToBuffer = orig.storeToBuffer; + cloneStream.replayFromBuffer = orig.replayFromBuffer; + cloneStream.closed = orig.closed; + return cloneStream; + } + + private void cloneBufferOnMarkStream(BufferOnMarkInputStream clone, BufferOnMarkInputStream orig) { + if (orig.ringBuffer.buffer != null) { + clone.ringBuffer.buffer = Arrays.copyOf(orig.ringBuffer.buffer, orig.ringBuffer.buffer.length); + } else { + clone.ringBuffer.buffer = null; + } + clone.ringBuffer.head = orig.ringBuffer.head; + clone.ringBuffer.tail = orig.ringBuffer.tail; + clone.ringBuffer.position = orig.ringBuffer.position; + clone.storeToBuffer = orig.storeToBuffer; + 
clone.replayFromBuffer = orig.replayFromBuffer; + clone.closed = orig.closed; + } + + private void assertArray(int offset, byte[] test) { + for (int i = 0; i < test.length; i++) { + Assert.assertThat(test[i], Matchers.is(testArray[offset + i])); + } + } + + private Tuple getMockInfiniteInputStream() throws IOException { + InputStream mockSource = mock(InputStream.class); + AtomicInteger bytesRead = new AtomicInteger(0); + when(mockSource.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())).thenAnswer( + invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + int bytesCount = 1 + Randomness.get().nextInt(len); + bytesRead.addAndGet(bytesCount); + return bytesCount; + } + } + ); + when(mockSource.read()).thenAnswer(invocationOnMock -> { + bytesRead.incrementAndGet(); + return Randomness.get().nextInt(256); + }); + when(mockSource.skip(org.mockito.Matchers.anyLong())).thenAnswer(invocationOnMock -> { + final long n = (long) invocationOnMock.getArguments()[0]; + if (n <= 0) { + return 0; + } + int bytesSkipped = 1 + Randomness.get().nextInt(Math.toIntExact(n)); + bytesRead.addAndGet(bytesSkipped); + return bytesSkipped; + }); + when(mockSource.available()).thenReturn(1 + Randomness.get().nextInt(32)); + when(mockSource.markSupported()).thenReturn(false); + return new Tuple<>(bytesRead, mockSource); + } + + private static void skipNBytes(InputStream in, long n) throws IOException { + if (n > 0) { + long ns = in.skip(n); + if (ns >= 0 && ns < n) { // skipped too few bytes + // adjust number to skip + n -= ns; + // read until requested number skipped or EOS reached + while (n > 0 && in.read() != -1) { + n--; + } + // if not enough skipped, then EOFE + if (n != 0) { + throw new EOFException(); + } + } else if (ns != n) { // skipped negative or too many bytes + throw new IOException("Unable to skip exactly"); + } + } + } + + static class NoMarkByteArrayInputStream 
extends ByteArrayInputStream { + + NoMarkByteArrayInputStream(byte[] buf) { + super(buf); + } + + NoMarkByteArrayInputStream(byte[] buf, int offset, int length) { + super(buf, offset, length); + } + + int getPos() { + return pos; + } + + int getCount() { + return count; + } + + @Override + public void mark(int readlimit) {} + + @Override + public boolean markSupported() { + return false; + } + + @Override + public void reset() { + throw new IllegalStateException("Mark not called or has been invalidated"); + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java new file mode 100644 index 0000000000000..4a8d1b017156f --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/ChainingInputStreamTests.java @@ -0,0 +1,1173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.mockito.Mockito; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readAllBytes; +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readNBytes; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ChainingInputStreamTests extends ESTestCase { + + public void testChainComponentsWhenUsingFactoryMethod() throws Exception { + InputStream input1 = mock(InputStream.class); + when(input1.markSupported()).thenReturn(true); + when(input1.read()).thenReturn(randomIntBetween(0, 255)); + InputStream input2 = mock(InputStream.class); + when(input2.markSupported()).thenReturn(true); + when(input2.read()).thenReturn(randomIntBetween(0, 255)); + + ChainingInputStream chain = ChainingInputStream.chain(input1, input2); + + chain.read(); + verify(input1).read(); + verify(input2, times(0)).read(); + + when(input1.read()).thenReturn(-1); + chain.read(); + verify(input1, times(2)).read(); + verify(input1, times(0)).close(); + verify(input2).read(); + + when(input2.read()).thenReturn(-1); + chain.read(); + verify(input1, times(2)).read(); + verify(input2, times(2)).read(); + verify(input1, 
times(0)).close(); + verify(input2, times(0)).close(); + + chain.close(); + verify(input1).close(); + verify(input2).close(); + } + + public void testMarkAndResetWhenUsingFactoryMethod() throws Exception { + InputStream input1 = mock(InputStream.class); + when(input1.markSupported()).thenReturn(true); + when(input1.read()).thenReturn(randomIntBetween(0, 255)); + InputStream input2 = mock(InputStream.class); + when(input2.markSupported()).thenReturn(true); + when(input2.read()).thenReturn(randomIntBetween(0, 255)); + + ChainingInputStream chain = ChainingInputStream.chain(input1, input2); + verify(input1, times(1)).mark(anyInt()); + verify(input2, times(1)).mark(anyInt()); + + // mark at the beginning + chain.mark(randomIntBetween(1, 32)); + verify(input1, times(1)).mark(anyInt()); + verify(input2, times(1)).mark(anyInt()); + + verify(input1, times(0)).reset(); + chain.read(); + verify(input1, times(1)).reset(); + chain.reset(); + verify(input1, times(0)).close(); + verify(input1, times(1)).reset(); + chain.read(); + verify(input1, times(2)).reset(); + + // mark at the first component + chain.mark(randomIntBetween(1, 32)); + verify(input1, times(2)).mark(anyInt()); + verify(input2, times(1)).mark(anyInt()); + + when(input1.read()).thenReturn(-1); + chain.read(); + verify(input1, times(0)).close(); + chain.reset(); + verify(input1, times(3)).reset(); + + chain.read(); + verify(input2, times(2)).reset(); + + // mark at the second component + chain.mark(randomIntBetween(1, 32)); + verify(input1, times(2)).mark(anyInt()); + verify(input2, times(2)).mark(anyInt()); + + when(input2.read()).thenReturn(-1); + chain.read(); + verify(input1, times(0)).close(); + verify(input2, times(0)).close(); + chain.reset(); + verify(input2, times(3)).reset(); + + chain.close(); + verify(input1, times(1)).close(); + verify(input2, times(1)).close(); + } + + public void testSkipWithinComponent() throws Exception { + byte[] b1 = randomByteArrayOfLength(randomIntBetween(2, 16)); + 
ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return new ByteArrayInputStream(b1); + } else { + return null; + } + } + }; + int prefix = randomIntBetween(0, b1.length - 2); + readNBytes(test, prefix); + // skip less bytes than the component has + int nSkip1 = randomInt(b1.length - prefix); + long nSkip = test.skip(nSkip1); + assertThat((int) nSkip, Matchers.is(nSkip1)); + int nSkip2 = b1.length - prefix - nSkip1 + randomIntBetween(1, 8); + // skip more bytes than the component has + nSkip = test.skip(nSkip2); + assertThat((int) nSkip, Matchers.is(b1.length - prefix - nSkip1)); + } + + public void testSkipAcrossComponents() throws Exception { + byte[] b1 = randomByteArrayOfLength(randomIntBetween(1, 16)); + byte[] b2 = randomByteArrayOfLength(randomIntBetween(1, 16)); + ChainingInputStream test = new ChainingInputStream() { + final Iterator iter = Arrays.asList(new ByteArrayInputStream(b1), new ByteArrayInputStream(b2)) + .iterator(); + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (iter.hasNext()) { + return iter.next(); + } else { + return null; + } + } + }; + long skipArg = b1.length + randomIntBetween(1, b2.length); + long nSkip = test.skip(skipArg); + assertThat(nSkip, Matchers.is(skipArg)); + byte[] rest = readAllBytes(test); + assertThat((long) rest.length, Matchers.is(b1.length + b2.length - nSkip)); + for (int i = rest.length - 1; i >= 0; i--) { + assertThat(rest[i], Matchers.is(b2[i + (int) nSkip - b1.length])); + } + } + + public void testEmptyChain() throws Exception { + // chain is empty because it doesn't have any components + ChainingInputStream emptyStream = newEmptyStream(false); + assertThat(emptyStream.read(), Matchers.is(-1)); + emptyStream = newEmptyStream(false); + byte[] b = randomByteArrayOfLength(randomIntBetween(1, 8)); + int off = randomInt(b.length - 
1); + assertThat(emptyStream.read(b, off, b.length - off), Matchers.is(-1)); + emptyStream = newEmptyStream(false); + assertThat(emptyStream.available(), Matchers.is(0)); + emptyStream = newEmptyStream(false); + assertThat(emptyStream.skip(randomIntBetween(1, 32)), Matchers.is(0L)); + // chain is empty because all its components are empty + emptyStream = newEmptyStream(true); + assertThat(emptyStream.read(), Matchers.is(-1)); + emptyStream = newEmptyStream(true); + b = randomByteArrayOfLength(randomIntBetween(1, 8)); + off = randomInt(b.length - 1); + assertThat(emptyStream.read(b, off, b.length - off), Matchers.is(-1)); + emptyStream = newEmptyStream(true); + assertThat(emptyStream.available(), Matchers.is(0)); + emptyStream = newEmptyStream(true); + assertThat(emptyStream.skip(randomIntBetween(1, 32)), Matchers.is(0L)); + } + + public void testClose() throws Exception { + ChainingInputStream test1 = newEmptyStream(randomBoolean()); + test1.close(); + IOException e = expectThrows(IOException.class, () -> { test1.read(); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test2 = newEmptyStream(randomBoolean()); + test2.close(); + byte[] b = randomByteArrayOfLength(randomIntBetween(2, 9)); + int off = randomInt(b.length - 2); + e = expectThrows(IOException.class, () -> { test2.read(b, off, randomInt(b.length - off - 1)); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test3 = newEmptyStream(randomBoolean()); + test3.close(); + e = expectThrows(IOException.class, () -> { test3.skip(randomInt(31)); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test4 = newEmptyStream(randomBoolean()); + test4.close(); + e = expectThrows(IOException.class, () -> { test4.available(); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test5 = newEmptyStream(randomBoolean()); + test5.close(); + e = expectThrows(IOException.class, () 
-> { test5.reset(); }); + assertThat(e.getMessage(), Matchers.is("Stream is closed")); + ChainingInputStream test6 = newEmptyStream(randomBoolean()); + test6.close(); + try { + test6.mark(randomInt()); + } catch (Exception e1) { + assumeNoException("mark on a closed stream should not throw", e1); + } + } + + public void testInitialComponentArgumentIsNull() throws Exception { + AtomicReference initialInputStream = new AtomicReference<>(); + AtomicBoolean nextCalled = new AtomicBoolean(false); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + initialInputStream.set(currentComponentIn); + nextCalled.set(true); + return null; + } + }; + assertThat(test.read(), Matchers.is(-1)); + assertThat(nextCalled.get(), Matchers.is(true)); + assertThat(initialInputStream.get(), Matchers.nullValue()); + } + + public void testChaining() throws Exception { + int componentCount = randomIntBetween(2, 9); + ByteBuffer testSource = ByteBuffer.allocate(componentCount); + TestInputStream[] sourceComponents = new TestInputStream[componentCount]; + for (int i = 0; i < sourceComponents.length; i++) { + byte[] b = randomByteArrayOfLength(randomInt(1)); + testSource.put(b); + sourceComponents[i] = new TestInputStream(b); + } + ChainingInputStream test = new ChainingInputStream() { + int i = 0; + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (i == 0) { + assertThat(currentComponentIn, Matchers.nullValue()); + return sourceComponents[i++]; + } else if (i < sourceComponents.length) { + assertThat(((TestInputStream) currentComponentIn).closed.get(), Matchers.is(true)); + assertThat(currentComponentIn, Matchers.is(sourceComponents[i - 1])); + return sourceComponents[i++]; + } else if (i == sourceComponents.length) { + assertThat(((TestInputStream) currentComponentIn).closed.get(), Matchers.is(true)); + assertThat(currentComponentIn, 
Matchers.is(sourceComponents[i - 1])); + i++; + return null; + } else { + throw new IllegalStateException(); + } + } + + @Override + public boolean markSupported() { + return false; + } + }; + byte[] testArr = readAllBytes(test); + byte[] ref = testSource.array(); + // testArr and ref should be equal, but ref might have trailing zeroes + for (int i = 0; i < testArr.length; i++) { + assertThat(testArr[i], Matchers.is(ref[i])); + } + } + + public void testEmptyInputStreamComponents() throws Exception { + // leading single empty stream + Tuple test = testEmptyComponentsInChain(3, Arrays.asList(0)); + byte[] result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // leading double empty streams + test = testEmptyComponentsInChain(3, Arrays.asList(0, 1)); + result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // trailing single empty stream + test = testEmptyComponentsInChain(3, Arrays.asList(2)); + result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // trailing double empty stream + test = testEmptyComponentsInChain(3, Arrays.asList(1, 2)); + result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // middle single empty stream + test = testEmptyComponentsInChain(3, Arrays.asList(1)); + result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // leading and trailing empty 
streams + test = testEmptyComponentsInChain(3, Arrays.asList(0, 2)); + result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(test.v2().length)); + for (int i = 0; i < result.length; i++) { + assertThat(result[i], Matchers.is(test.v2()[i])); + } + // all streams are empty + test = testEmptyComponentsInChain(3, Arrays.asList(0, 1, 2)); + result = readAllBytes(test.v1()); + assertThat(result.length, Matchers.is(0)); + } + + public void testNullComponentTerminatesChain() throws Exception { + TestInputStream[] sourceComponents = new TestInputStream[3]; + TestInputStream[] chainComponents = new TestInputStream[5]; + byte[] b1 = randomByteArrayOfLength(randomIntBetween(1, 2)); + sourceComponents[0] = new TestInputStream(b1); + sourceComponents[1] = null; + byte[] b2 = randomByteArrayOfLength(randomIntBetween(1, 2)); + sourceComponents[2] = new TestInputStream(b2); + ChainingInputStream test = new ChainingInputStream() { + int i = 0; + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + chainComponents[i] = (TestInputStream) currentComponentIn; + if (i < sourceComponents.length) { + return sourceComponents[i++]; + } else { + i++; + return null; + } + } + + @Override + public boolean markSupported() { + return false; + } + }; + assertThat(readAllBytes(test), Matchers.equalTo(b1)); + assertThat(chainComponents[0], Matchers.nullValue()); + assertThat(chainComponents[1], Matchers.is(sourceComponents[0])); + assertThat(chainComponents[1].closed.get(), Matchers.is(true)); + assertThat(chainComponents[2], Matchers.nullValue()); + assertThat(chainComponents[3], Matchers.nullValue()); + } + + public void testCallsForwardToCurrentComponent() throws Exception { + InputStream mockCurrentIn = mock(InputStream.class); + when(mockCurrentIn.markSupported()).thenReturn(true); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if 
(currentComponentIn == null) { + return mockCurrentIn; + } else { + throw new IllegalStateException(); + } + } + }; + // verify "byte-wise read" is proxied to the current component stream + when(mockCurrentIn.read()).thenReturn(randomInt(255)); + test.read(); + verify(mockCurrentIn).read(); + // verify "array read" is proxied to the current component stream + when(mockCurrentIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + // partial read return + int bytesCount = randomIntBetween(1, len); + return bytesCount; + } + }); + byte[] b = randomByteArrayOfLength(randomIntBetween(2, 33)); + int len = randomIntBetween(1, b.length - 1); + int offset = randomInt(b.length - len - 1); + test.read(b, offset, len); + verify(mockCurrentIn).read(Mockito.eq(b), Mockito.eq(offset), Mockito.eq(len)); + // verify "skip" is proxied to the current component stream + long skipCount = randomIntBetween(1, 3); + test.skip(skipCount); + verify(mockCurrentIn).skip(Mockito.eq(skipCount)); + // verify "available" is proxied to the current component stream + test.available(); + verify(mockCurrentIn).available(); + } + + public void testEmptyReadAsksForNext() throws Exception { + InputStream mockCurrentIn = mock(InputStream.class); + when(mockCurrentIn.markSupported()).thenReturn(true); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + return mockCurrentIn; + } + }; + test.currentIn = new ByteArrayInputStream(new byte[0]); + when(mockCurrentIn.read()).thenReturn(randomInt(255)); + test.read(); + verify(mockCurrentIn).read(); + // test "array read" + test.currentIn = new ByteArrayInputStream(new byte[0]); + when(mockCurrentIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), 
org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + int bytesCount = randomIntBetween(1, len); + return bytesCount; + } + }); + byte[] b = new byte[randomIntBetween(2, 33)]; + int len = randomIntBetween(1, b.length - 1); + int offset = randomInt(b.length - len - 1); + test.read(b, offset, len); + verify(mockCurrentIn).read(Mockito.eq(b), Mockito.eq(offset), Mockito.eq(len)); + } + + public void testReadAll() throws Exception { + byte[] b = randomByteArrayOfLength(randomIntBetween(2, 33)); + int splitIdx = randomInt(b.length - 2); + ByteArrayInputStream first = new ByteArrayInputStream(b, 0, splitIdx + 1); + ByteArrayInputStream second = new ByteArrayInputStream(b, splitIdx + 1, b.length - splitIdx - 1); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentElementIn) throws IOException { + if (currentElementIn == null) { + return first; + } else if (currentElementIn == first) { + return second; + } else if (currentElementIn == second) { + return null; + } else { + throw new IllegalArgumentException(); + } + } + }; + byte[] result = readAllBytes(test); + assertThat(result, Matchers.equalTo(b)); + } + + public void testMarkAtBeginning() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + assertThat(test.currentIn, Matchers.nullValue()); + // mark at the beginning + assertThat(test.markIn, Matchers.nullValue()); + test.mark(randomInt(63)); + assertThat(test.markIn, Matchers.nullValue()); + // another mark is a no-op + 
test.mark(randomInt(63)); + assertThat(test.markIn, Matchers.nullValue()); + // read does not change the marK + test.read(); + assertThat(test.currentIn, Matchers.is(mockIn)); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + // read reaches end + when(mockIn.read()).thenReturn(-1); + test.read(); + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + verify(mockIn).close(); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + } + + public void testMarkAtEnding() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomFrom(-1, randomInt(255))); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + // read all bytes + while (test.read() != -1) { + } + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is null (beginning) + assertThat(test.markIn, Matchers.nullValue()); + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // another mark is a no-op + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + } + + public void testSingleMarkAnywhere() throws Exception { + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + when(mockIn.read()).thenAnswer(invocationOnMock -> randomFrom(-1, randomInt(1))); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + 
return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + readNBytes(test, randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn).mark(Mockito.eq(readLimit)); + // mark again, same position + int readLimit2 = randomInt(63); + test.mark(readLimit2); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + if (readLimit != readLimit2) { + verify(currentIn).mark(Mockito.eq(readLimit2)); + } else { + verify(currentIn, times(2)).mark(Mockito.eq(readLimit)); + } + // read more (possibly moving on to a new component) + readNBytes(test, randomInt(63)); + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + // read until the end + chainingInputStreamEOF.set(true); + readAllBytes(test); + // current component is at the end + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is still put + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.markIn, never()).close(); + // but close also closes the mark + test.close(); + verify(test.markIn).close(); + } + + public void testMarkOverwritesPreviousMark() throws Exception { + AtomicBoolean chainingInputStreamEOF = new 
AtomicBoolean(false); + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + // single byte read never returns "-1" so it never advances component + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + readNBytes(test, randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.markIn).mark(Mockito.eq(readLimit)); + // read more within the same component + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.read(); + } + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + // mark again + int readLimit2 = randomInt(63); + test.mark(readLimit2); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn, never()).close(); + if (readLimit != readLimit2) { + verify(currentIn).mark(Mockito.eq(readLimit2)); + } else { + verify(currentIn, times(2)).mark(Mockito.eq(readLimit)); + } + // read more while switching the component + for 
(int i = 0; i < randomIntBetween(4, 16) || test.currentIn == currentIn; i++) { + readNBytes(test, randomInt(63)); + } + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + // mark again + readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.markIn, Matchers.is(test.currentIn)); + // previous mark closed + verify(currentIn).close(); + verify(test.markIn).mark(Mockito.eq(readLimit)); + InputStream markIn = test.markIn; + // read until the end + chainingInputStreamEOF.set(true); + readAllBytes(test); + // current component is at the end + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is still put + assertThat(test.markIn, Matchers.is(markIn)); + verify(test.markIn, never()).close(); + // mark at the end + readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + verify(markIn).close(); + } + + public void testResetAtBeginning() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + assertThat(test.currentIn, Matchers.nullValue()); + assertThat(test.markIn, Matchers.nullValue()); + if (randomBoolean()) { + // mark at the beginning + test.mark(randomInt(63)); + assertThat(test.markIn, Matchers.nullValue()); + } + // reset immediately + test.reset(); + assertThat(test.currentIn, Matchers.nullValue()); + // read does not change the marK + test.read(); + assertThat(test.currentIn, Matchers.is(mockIn)); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + // reset back to beginning + test.reset(); + 
verify(mockIn).close(); + assertThat(test.currentIn, Matchers.nullValue()); + // read reaches end + when(mockIn.read()).thenReturn(-1); + test.read(); + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark reference is still unchanged + assertThat(test.markIn, Matchers.nullValue()); + // reset back to beginning + test.reset(); + assertThat(test.currentIn, Matchers.nullValue()); + } + + public void testResetAtEnding() throws Exception { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + when(mockIn.read()).thenAnswer(invocationOnMock -> randomFrom(-1, randomInt(255))); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return mockIn; + } else { + return null; + } + } + }; + // read all bytes + while (test.read() != -1) { + } + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is null (beginning) + assertThat(test.markIn, Matchers.nullValue()); + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // reset + test.reset(); + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + assertThat(test.read(), Matchers.is(-1)); + // another mark is a no-op + test.mark(randomInt(255)); + assertThat(test.markIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + assertThat(test.read(), Matchers.is(-1)); + } + + public void testResetForSingleMarkAnywhere() throws Exception { + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + // single byte read never returns "-1" so it never advances component + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), 
org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; + AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + AtomicReference nextComponentArg = new AtomicReference<>(); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (nextComponentArg.get() != null) { + Assert.assertThat(currentComponentIn, Matchers.is(nextComponentArg.get())); + nextComponentArg.set(null); + } + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + readNBytes(test, randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn).mark(Mockito.eq(readLimit)); + // read more without moving to a new component + for (int i = 0; i < randomIntBetween(4, 16); i++) { + test.read(); + } + // first reset + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.currentIn, never()).close(); + verify(test.currentIn).reset(); + // read more, moving on to a new component + for (int i = 0; i < randomIntBetween(4, 16) || test.currentIn == currentIn; i++) { + readNBytes(test, randomInt(63)); + } + // mark does not budge + assertThat(test.markIn, Matchers.is(currentIn)); + assertThat(test.currentIn, Matchers.not(currentIn)); + InputStream lastCurrentIn = 
test.currentIn; + // second reset + test.reset(); + verify(lastCurrentIn).close(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.currentIn, times(2)).reset(); + // assert the "nextComponent" argument + nextComponentArg.set(currentIn); + // read more, moving on to a new component + for (int i = 0; i < randomIntBetween(4, 16) || test.currentIn == currentIn; i++) { + readNBytes(test, randomInt(63)); + } + // read until the end + chainingInputStreamEOF.set(true); + readAllBytes(test); + // current component is at the end + assertThat(test.currentIn, Matchers.is(ChainingInputStream.EXHAUSTED_MARKER)); + // mark is still put + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.markIn, never()).close(); + // reset when stream is at the end + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(test.currentIn, times(3)).reset(); + // assert the "nextComponent" argument + nextComponentArg.set(currentIn); + // read more to verify that current component is passed as nextComponent argument + readAllBytes(test); + } + + public void testResetForDoubleMarkAnywhere() throws Exception { + Supplier mockInputStreamSupplier = () -> { + InputStream mockIn = mock(InputStream.class); + when(mockIn.markSupported()).thenReturn(true); + try { + // single byte read never returns "-1" so it never advances component + when(mockIn.read()).thenAnswer(invocationOnMock -> randomInt(255)); + when(mockIn.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())) + .thenAnswer(invocationOnMock -> { + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (randomBoolean()) { + return -1; + } else { + // partial read return + return randomIntBetween(1, len); + } + } + }); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return mockIn; + }; 
+ AtomicBoolean chainingInputStreamEOF = new AtomicBoolean(false); + AtomicReference nextComponentArg = new AtomicReference<>(); + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (nextComponentArg.get() != null) { + Assert.assertThat(currentComponentIn, Matchers.is(nextComponentArg.get())); + nextComponentArg.set(null); + } + if (chainingInputStreamEOF.get()) { + return null; + } else { + return mockInputStreamSupplier.get(); + } + } + }; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + readNBytes(test, randomInt(63)); + } + InputStream currentIn = test.currentIn; + int readLimit = randomInt(63); + // first mark + test.mark(readLimit); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + verify(currentIn).mark(Mockito.eq(readLimit)); + // possibly skips over several components + for (int i = 0; i < randomIntBetween(1, 2); i++) { + readNBytes(test, randomInt(63)); + } + InputStream lastCurrentIn = test.currentIn; + // second mark + readLimit = randomInt(63); + test.mark(readLimit); + if (lastCurrentIn != currentIn) { + verify(currentIn).close(); + } + assertThat(test.currentIn, Matchers.is(lastCurrentIn)); + assertThat(test.markIn, Matchers.is(lastCurrentIn)); + verify(lastCurrentIn).mark(Mockito.eq(readLimit)); + currentIn = lastCurrentIn; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(1, 2); i++) { + readNBytes(test, randomInt(63)); + } + lastCurrentIn = test.currentIn; + // reset + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + if (lastCurrentIn != currentIn) { + verify(lastCurrentIn).close(); + } + verify(currentIn).reset(); + // assert the "nextComponet" arg is the current component + nextComponentArg.set(currentIn); + // possibly skips over 
several components + for (int i = 0; i < randomIntBetween(4, 16); i++) { + readNBytes(test, randomInt(63)); + } + lastCurrentIn = test.currentIn; + // third mark after reset + readLimit = randomInt(63); + test.mark(readLimit); + if (lastCurrentIn != currentIn) { + verify(currentIn).close(); + } + assertThat(test.currentIn, Matchers.is(lastCurrentIn)); + assertThat(test.markIn, Matchers.is(lastCurrentIn)); + verify(lastCurrentIn).mark(Mockito.eq(readLimit)); + nextComponentArg.set(lastCurrentIn); + currentIn = lastCurrentIn; + // possibly skips over several components + for (int i = 0; i < randomIntBetween(1, 2); i++) { + readNBytes(test, randomInt(63)); + } + lastCurrentIn = test.currentIn; + // reset after mark after reset + test.reset(); + assertThat(test.currentIn, Matchers.is(currentIn)); + assertThat(test.markIn, Matchers.is(currentIn)); + if (lastCurrentIn != currentIn) { + verify(lastCurrentIn).close(); + } + verify(currentIn).reset(); + } + + public void testMarkAfterResetNoMock() throws Exception { + int len = randomIntBetween(8, 15); + byte[] b = randomByteArrayOfLength(len); + for (int p = 0; p <= len; p++) { + for (int mark1 = 0; mark1 < len; mark1++) { + for (int offset1 = 0; offset1 < len - mark1; offset1++) { + for (int mark2 = 0; mark2 < len - mark1; mark2++) { + for (int offset2 = 0; offset2 < len - mark1 - mark2; offset2++) { + final int pivot = p; + ChainingInputStream test = new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (currentComponentIn == null) { + return new TestInputStream(b, 0, pivot, 1); + } else if (((TestInputStream) currentComponentIn).label == 1) { + return new TestInputStream(b, pivot, len - pivot, 2); + } else if (((TestInputStream) currentComponentIn).label == 2) { + return null; + } else { + throw new IllegalStateException(); + } + } + }; + // read "mark1" bytes + byte[] pre = readNBytes(test, mark1); + for (int i = 0; i < pre.length; i++) { + 
assertThat(pre[i], Matchers.is(b[i])); + } + // first mark + test.mark(len); + // read "offset" bytes + byte[] span1 = readNBytes(test, offset1); + for (int i = 0; i < span1.length; i++) { + assertThat(span1[i], Matchers.is(b[mark1 + i])); + } + // reset back to "mark1" offset + test.reset(); + // read/replay "mark2" bytes + byte[] span2 = readNBytes(test, mark2); + for (int i = 0; i < span2.length; i++) { + assertThat(span2[i], Matchers.is(b[mark1 + i])); + } + // second mark + test.mark(len); + byte[] span3 = readNBytes(test, offset2); + for (int i = 0; i < span3.length; i++) { + assertThat(span3[i], Matchers.is(b[mark1 + mark2 + i])); + } + // reset to second mark + test.reset(); + // read rest of bytes + byte[] span4 = readAllBytes(test); + for (int i = 0; i < span4.length; i++) { + assertThat(span4[i], Matchers.is(b[mark1 + mark2 + i])); + } + } + } + } + } + } + } + + private byte[] concatenateArrays(byte[] b1, byte[] b2) { + byte[] result = new byte[b1.length + b2.length]; + System.arraycopy(b1, 0, result, 0, b1.length); + System.arraycopy(b2, 0, result, b1.length, b2.length); + return result; + } + + private Tuple testEmptyComponentsInChain(int componentCount, List emptyComponentIndices) + throws Exception { + byte[] result = new byte[0]; + InputStream[] sourceComponents = new InputStream[componentCount]; + for (int i = 0; i < componentCount; i++) { + if (emptyComponentIndices.contains(i)) { + sourceComponents[i] = new ByteArrayInputStream(new byte[0]); + } else { + byte[] b = randomByteArrayOfLength(randomIntBetween(1, 8)); + sourceComponents[i] = new ByteArrayInputStream(b); + result = concatenateArrays(result, b); + } + } + return new Tuple<>(new ChainingInputStream() { + int i = 0; + + @Override + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (i < sourceComponents.length) { + return sourceComponents[i++]; + } else { + return null; + } + } + + @Override + public boolean markSupported() { + return false; + } + }, 
result); + } + + private ChainingInputStream newEmptyStream(boolean hasEmptyComponents) { + if (hasEmptyComponents) { + final Iterator iterator = Arrays.asList( + randomArray(1, 5, ByteArrayInputStream[]::new, () -> new ByteArrayInputStream(new byte[0])) + ).iterator(); + return new ChainingInputStream() { + InputStream nextComponent(InputStream currentComponentIn) throws IOException { + if (iterator.hasNext()) { + return iterator.next(); + } else { + return null; + } + } + }; + } else { + return new ChainingInputStream() { + @Override + InputStream nextComponent(InputStream currentElementIn) throws IOException { + return null; + } + }; + } + } + + static class TestInputStream extends InputStream { + + final byte[] b; + final int label; + final int len; + int i = 0; + int mark = -1; + final AtomicBoolean closed = new AtomicBoolean(false); + + TestInputStream(byte[] b) { + this(b, 0, b.length, 0); + } + + TestInputStream(byte[] b, int label) { + this(b, 0, b.length, label); + } + + TestInputStream(byte[] b, int offset, int len, int label) { + this.b = b; + this.i = offset; + this.len = len; + this.label = label; + } + + @Override + public int read() throws IOException { + if (b == null || i >= len) { + return -1; + } + return b[i++] & 0xFF; + } + + @Override + public void close() throws IOException { + closed.set(true); + } + + @Override + public void mark(int readlimit) { + this.mark = i; + } + + @Override + public void reset() { + this.i = this.mark; + } + + @Override + public boolean markSupported() { + return true; + } + + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java new file mode 100644 index 0000000000000..043c06bede8e6 --- /dev/null +++ 
b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/CountingInputStreamTests.java @@ -0,0 +1,164 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readAllBytes; +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readNBytes; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class CountingInputStreamTests extends ESTestCase { + + private static byte[] testArray; + + @BeforeClass + static void createTestArray() throws Exception { + testArray = new byte[32]; + for (int i = 0; i < testArray.length; i++) { + testArray[i] = (byte) i; + } + } + + public void testWrappedMarkAndClose() throws Exception { + AtomicBoolean isClosed = new AtomicBoolean(false); + InputStream mockIn = mock(InputStream.class); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocation) { + isClosed.set(true); + return null; + } + }).when(mockIn).close(); + new CountingInputStream(mockIn, true).close(); + assertThat(isClosed.get(), Matchers.is(true)); + isClosed.set(false); + new CountingInputStream(mockIn, false).close(); + assertThat(isClosed.get(), Matchers.is(false)); + when(mockIn.markSupported()).thenAnswer(invocationOnMock 
-> { return false; }); + assertThat(new CountingInputStream(mockIn, randomBoolean()).markSupported(), Matchers.is(false)); + when(mockIn.markSupported()).thenAnswer(invocationOnMock -> { return true; }); + assertThat(new CountingInputStream(mockIn, randomBoolean()).markSupported(), Matchers.is(true)); + } + + public void testSimpleCountForRead() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + int readLen = Randomness.get().nextInt(testArray.length); + readNBytes(test, readLen); + assertThat(test.getCount(), Matchers.is((long) readLen)); + readLen = testArray.length - readLen; + readNBytes(test, readLen); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + + public void testSimpleCountForSkip() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + int skipLen = Randomness.get().nextInt(testArray.length); + test.skip(skipLen); + assertThat(test.getCount(), Matchers.is((long) skipLen)); + skipLen = testArray.length - skipLen; + readNBytes(test, skipLen); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + + public void testCountingForMarkAndReset() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + assertThat(test.markSupported(), Matchers.is(true)); + int offset1 = Randomness.get().nextInt(testArray.length - 1); + if (randomBoolean()) { + test.skip(offset1); + } else { + test.read(new byte[offset1]); + } + assertThat(test.getCount(), Matchers.is((long) offset1)); + 
test.mark(testArray.length); + int offset2 = 1 + Randomness.get().nextInt(testArray.length - offset1 - 1); + if (randomBoolean()) { + test.skip(offset2); + } else { + test.read(new byte[offset2]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset2)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1)); + int offset3 = Randomness.get().nextInt(offset2); + if (randomBoolean()) { + test.skip(offset3); + } else { + test.read(new byte[offset3]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1)); + readAllBytes(test); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + + public void testCountingForMarkAfterReset() throws Exception { + CountingInputStream test = new CountingInputStream(new ByteArrayInputStream(testArray), randomBoolean()); + assertThat(test.getCount(), Matchers.is(0L)); + assertThat(test.markSupported(), Matchers.is(true)); + int offset1 = Randomness.get().nextInt(testArray.length - 1); + if (randomBoolean()) { + test.skip(offset1); + } else { + test.read(new byte[offset1]); + } + assertThat(test.getCount(), Matchers.is((long) offset1)); + test.mark(testArray.length); + int offset2 = 1 + Randomness.get().nextInt(testArray.length - offset1 - 1); + if (randomBoolean()) { + test.skip(offset2); + } else { + test.read(new byte[offset2]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset2)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1)); + int offset3 = Randomness.get().nextInt(offset2); + if (randomBoolean()) { + test.skip(offset3); + } else { + test.read(new byte[offset3]); + } + test.mark(testArray.length); + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3)); + int offset4 = Randomness.get().nextInt(testArray.length - offset1 - offset3); + if 
(randomBoolean()) { + test.skip(offset4); + } else { + test.read(new byte[offset4]); + } + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3 + offset4)); + test.reset(); + assertThat(test.getCount(), Matchers.is((long) offset1 + offset3)); + readAllBytes(test); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + test.close(); + assertThat(test.getCount(), Matchers.is((long) testArray.length)); + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java new file mode 100644 index 0000000000000..fc3f0a202b221 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/DecryptionPacketsInputStreamTests.java @@ -0,0 +1,200 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import java.io.ByteArrayInputStream; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.security.SecureRandom; +import java.util.Arrays; + +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readAllBytes; + +public class DecryptionPacketsInputStreamTests extends ESTestCase { + + public void testSuccessEncryptAndDecryptSmallPacketLength() throws Exception { + int len = 8 + Randomness.get().nextInt(8); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + for (int packetLen : Arrays.asList(1, 2, 3, 4)) { + testEncryptAndDecryptSuccess(plainBytes, secretKey, nonce, packetLen); + } + } + + public void testSuccessEncryptAndDecryptLargePacketLength() throws Exception { + int len = 256 + Randomness.get().nextInt(256); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + for (int packetLen : Arrays.asList(len - 1, len - 2, len - 3, len - 4)) { + testEncryptAndDecryptSuccess(plainBytes, secretKey, nonce, packetLen); + } + } + + public void testSuccessEncryptAndDecryptTypicalPacketLength() throws Exception { + int len = 1024 + Randomness.get().nextInt(512); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + for (int packetLen : Arrays.asList(128, 256, 512)) { + testEncryptAndDecryptSuccess(plainBytes, secretKey, nonce, packetLen); + } + } + + public void testFailureEncryptAndDecryptWrongKey() throws 
Exception { + int len = 256 + Randomness.get().nextInt(256); + // 2-3 packets + int packetLen = 1 + Randomness.get().nextInt(len / 2); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey encryptSecretKey = generateSecretKey(); + SecretKey decryptSecretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + byte[] encryptedBytes; + try ( + InputStream in = new EncryptionPacketsInputStream( + new ByteArrayInputStream(plainBytes, 0, len), + encryptSecretKey, + nonce, + packetLen + ) + ) { + encryptedBytes = readAllBytes(in); + } + try (InputStream in = new DecryptionPacketsInputStream(new ByteArrayInputStream(encryptedBytes), decryptSecretKey, packetLen)) { + IOException e = expectThrows(IOException.class, () -> { readAllBytes(in); }); + assertThat(e.getMessage(), Matchers.is("Exception during packet decryption")); + } + } + + public void testFailureEncryptAndDecryptAlteredCiphertext() throws Exception { + int len = 8 + Randomness.get().nextInt(8); + // one packet + int packetLen = len + Randomness.get().nextInt(8); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + byte[] encryptedBytes; + try (InputStream in = new EncryptionPacketsInputStream(new ByteArrayInputStream(plainBytes, 0, len), secretKey, nonce, packetLen)) { + encryptedBytes = readAllBytes(in); + } + for (int i = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES; i < EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + len + + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; i++) { + for (int j = 0; j < 8; j++) { + // flip bit + encryptedBytes[i] ^= (1 << j); + // fail decryption + try (InputStream in = new DecryptionPacketsInputStream(new ByteArrayInputStream(encryptedBytes), secretKey, packetLen)) { + IOException e = expectThrows(IOException.class, () -> { readAllBytes(in); }); + assertThat(e.getMessage(), Matchers.is("Exception during packet 
decryption")); + } + // flip bit back + encryptedBytes[i] ^= (1 << j); + } + } + } + + public void testFailureEncryptAndDecryptAlteredCiphertextIV() throws Exception { + int len = 8 + Randomness.get().nextInt(8); + int packetLen = 4 + Randomness.get().nextInt(4); + byte[] plainBytes = new byte[len]; + Randomness.get().nextBytes(plainBytes); + SecretKey secretKey = generateSecretKey(); + int nonce = Randomness.get().nextInt(); + byte[] encryptedBytes; + try (InputStream in = new EncryptionPacketsInputStream(new ByteArrayInputStream(plainBytes, 0, len), secretKey, nonce, packetLen)) { + encryptedBytes = readAllBytes(in); + } + assertThat(encryptedBytes.length, Matchers.is((int) EncryptionPacketsInputStream.getEncryptionLength(len, packetLen))); + int encryptedPacketLen = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + packetLen + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + for (int i = 0; i < encryptedBytes.length; i += encryptedPacketLen) { + for (int j = 0; j < EncryptedRepository.GCM_IV_LENGTH_IN_BYTES; j++) { + for (int k = 0; k < 8; k++) { + // flip bit + encryptedBytes[i + j] ^= (1 << k); + try ( + InputStream in = new DecryptionPacketsInputStream(new ByteArrayInputStream(encryptedBytes), secretKey, packetLen) + ) { + IOException e = expectThrows(IOException.class, () -> { readAllBytes(in); }); + if (j < Integer.BYTES) { + assertThat(e.getMessage(), Matchers.startsWith("Exception during packet decryption")); + } else { + assertThat(e.getMessage(), Matchers.startsWith("Packet counter mismatch")); + } + } + // flip bit back + encryptedBytes[i + j] ^= (1 << k); + } + } + } + } + + private void testEncryptAndDecryptSuccess(byte[] plainBytes, SecretKey secretKey, int nonce, int packetLen) throws Exception { + for (int len = 0; len <= plainBytes.length; len++) { + byte[] encryptedBytes; + try ( + InputStream in = new EncryptionPacketsInputStream(new ByteArrayInputStream(plainBytes, 0, len), secretKey, nonce, packetLen) + ) { + encryptedBytes = readAllBytes(in); + } + 
assertThat((long) encryptedBytes.length, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(len, packetLen))); + byte[] decryptedBytes; + try ( + InputStream in = new DecryptionPacketsInputStream( + new ReadLessFilterInputStream(new ByteArrayInputStream(encryptedBytes)), + secretKey, + packetLen + ) + ) { + decryptedBytes = readAllBytes(in); + } + assertThat(decryptedBytes.length, Matchers.is(len)); + assertThat( + (long) decryptedBytes.length, + Matchers.is(DecryptionPacketsInputStream.getDecryptionLength(encryptedBytes.length, packetLen)) + ); + for (int i = 0; i < len; i++) { + assertThat(decryptedBytes[i], Matchers.is(plainBytes[i])); + } + } + } + + // input stream that reads less bytes than asked to, testing that packet-wide reads don't rely on `read` calls for memory buffers which + // always return the same number of bytes they are asked to + private static class ReadLessFilterInputStream extends FilterInputStream { + + protected ReadLessFilterInputStream(InputStream in) { + super(in); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + return super.read(b, off, randomIntBetween(1, len)); + } + } + + private SecretKey generateSecretKey() throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(256, new SecureRandom()); + return keyGen.generateKey(); + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java new file mode 100644 index 0000000000000..6bb46603f7ba9 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptedRepositoryTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterApplierService; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.blobstore.BlobContainer; +import org.elasticsearch.common.blobstore.BlobPath; +import org.elasticsearch.common.blobstore.BlobStore; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.repositories.RepositoryException; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.Before; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class EncryptedRepositoryTests extends ESTestCase { + + private SecureString repoPassword; + private BlobPath delegatedPath; + private 
BlobStore delegatedBlobStore; + private BlobStoreRepository delegatedRepository; + private RepositoryMetadata repositoryMetadata; + private EncryptedRepository encryptedRepository; + private EncryptedRepository.EncryptedBlobStore encryptedBlobStore; + private Map blobsMap; + + @Before + public void setUpMocks() throws Exception { + this.repoPassword = new SecureString(randomAlphaOfLength(20).toCharArray()); + this.delegatedPath = randomFrom( + BlobPath.cleanPath(), + BlobPath.cleanPath().add(randomAlphaOfLength(8)), + BlobPath.cleanPath().add(randomAlphaOfLength(4)).add(randomAlphaOfLength(4)) + ); + this.delegatedBlobStore = mock(BlobStore.class); + this.delegatedRepository = mock(BlobStoreRepository.class); + when(delegatedRepository.blobStore()).thenReturn(delegatedBlobStore); + when(delegatedRepository.basePath()).thenReturn(delegatedPath); + this.repositoryMetadata = new RepositoryMetadata( + randomAlphaOfLength(4), + EncryptedRepositoryPlugin.REPOSITORY_TYPE_NAME, + Settings.EMPTY + ); + ClusterApplierService clusterApplierService = mock(ClusterApplierService.class); + when(clusterApplierService.threadPool()).thenReturn(mock(ThreadPool.class)); + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.getClusterApplierService()).thenReturn(clusterApplierService); + this.encryptedRepository = new EncryptedRepository( + repositoryMetadata, + mock(NamedXContentRegistry.class), + clusterService, + mock(BigArrays.class), + mock(RecoverySettings.class), + delegatedRepository, + () -> mock(XPackLicenseState.class), + repoPassword + ); + this.encryptedBlobStore = (EncryptedRepository.EncryptedBlobStore) encryptedRepository.createBlobStore(); + this.blobsMap = new HashMap<>(); + doAnswer(invocationOnMockBlobStore -> { + BlobPath blobPath = ((BlobPath) invocationOnMockBlobStore.getArguments()[0]); + BlobContainer blobContainer = mock(BlobContainer.class); + // write atomic + doAnswer(invocationOnMockBlobContainer -> { + String DEKId = 
((String) invocationOnMockBlobContainer.getArguments()[0]); + BytesReference DEKBytesReference = ((BytesReference) invocationOnMockBlobContainer.getArguments()[1]); + this.blobsMap.put(blobPath.add(DEKId), BytesReference.toBytes(DEKBytesReference)); + return null; + }).when(blobContainer).writeBlobAtomic(any(String.class), any(BytesReference.class), anyBoolean()); + // read + doAnswer(invocationOnMockBlobContainer -> { + String DEKId = ((String) invocationOnMockBlobContainer.getArguments()[0]); + return new ByteArrayInputStream(blobsMap.get(blobPath.add(DEKId))); + }).when(blobContainer).readBlob(any(String.class)); + return blobContainer; + }).when(this.delegatedBlobStore).blobContainer(any(BlobPath.class)); + } + + public void testStoreDEKSuccess() throws Exception { + String DEKId = randomAlphaOfLengthBetween(16, 32); // at least 128 bits because of FIPS + SecretKey DEK = new SecretKeySpec(randomByteArrayOfLength(32), "AES"); + + encryptedBlobStore.storeDEK(DEKId, DEK); + + Tuple KEK = encryptedRepository.generateKEK(DEKId); + assertThat(blobsMap.keySet(), contains(delegatedPath.add(EncryptedRepository.DEK_ROOT_CONTAINER).add(DEKId).add(KEK.v1()))); + byte[] wrappedKey = blobsMap.values().iterator().next(); + SecretKey unwrappedKey = AESKeyUtils.unwrap(KEK.v2(), wrappedKey); + assertThat(unwrappedKey.getEncoded(), equalTo(DEK.getEncoded())); + } + + public void testGetDEKSuccess() throws Exception { + String DEKId = randomAlphaOfLengthBetween(16, 32); // at least 128 bits because of FIPS + SecretKey DEK = new SecretKeySpec(randomByteArrayOfLength(32), "AES"); + Tuple KEK = encryptedRepository.generateKEK(DEKId); + + byte[] wrappedDEK = AESKeyUtils.wrap(KEK.v2(), DEK); + blobsMap.put(delegatedPath.add(EncryptedRepository.DEK_ROOT_CONTAINER).add(DEKId).add(KEK.v1()), wrappedDEK); + + SecretKey loadedDEK = encryptedBlobStore.getDEKById(DEKId); + assertThat(loadedDEK.getEncoded(), equalTo(DEK.getEncoded())); + } + + public void testGetTamperedDEKFails() throws 
Exception { + String DEKId = randomAlphaOfLengthBetween(16, 32); // at least 128 bits because of FIPS + SecretKey DEK = new SecretKeySpec("01234567890123456789012345678901".getBytes(StandardCharsets.UTF_8), "AES"); + Tuple KEK = encryptedRepository.generateKEK(DEKId); + + byte[] wrappedDEK = AESKeyUtils.wrap(KEK.v2(), DEK); + int tamperPos = randomIntBetween(0, wrappedDEK.length - 1); + wrappedDEK[tamperPos] ^= 0xFF; + blobsMap.put(delegatedPath.add(EncryptedRepository.DEK_ROOT_CONTAINER).add(DEKId).add(KEK.v1()), wrappedDEK); + + RepositoryException e = expectThrows(RepositoryException.class, () -> encryptedBlobStore.getDEKById(DEKId)); + assertThat(e.repository(), equalTo(repositoryMetadata.name())); + assertThat(e.getMessage(), containsString("Failure to AES unwrap the DEK")); + } + + public void testGetDEKIOException() { + doAnswer(invocationOnMockBlobStore -> { + BlobPath blobPath = ((BlobPath) invocationOnMockBlobStore.getArguments()[0]); + BlobContainer blobContainer = mock(BlobContainer.class); + // read + doAnswer(invocationOnMockBlobContainer -> { throw new IOException("Tested IOException"); }).when(blobContainer) + .readBlob(any(String.class)); + return blobContainer; + }).when(this.delegatedBlobStore).blobContainer(any(BlobPath.class)); + IOException e = expectThrows(IOException.class, () -> encryptedBlobStore.getDEKById("this must be at least 16")); + assertThat(e.getMessage(), containsString("Tested IOException")); + } + + public void testGenerateKEK() { + String id1 = "fixed identifier 1"; + String id2 = "fixed identifier 2"; + Tuple KEK1 = encryptedRepository.generateKEK(id1); + Tuple KEK2 = encryptedRepository.generateKEK(id2); + assertThat(KEK1.v1(), not(equalTo(KEK2.v1()))); + assertThat(KEK1.v2(), not(equalTo(KEK2.v2()))); + Tuple sameKEK1 = encryptedRepository.generateKEK(id1); + assertThat(KEK1.v1(), equalTo(sameKEK1.v1())); + assertThat(KEK1.v2(), equalTo(sameKEK1.v2())); + } + +} diff --git 
a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java new file mode 100644 index 0000000000000..4f79cb9c5764c --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/EncryptionPacketsInputStreamTests.java @@ -0,0 +1,614 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.mockito.Mockito; + +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class EncryptionPacketsInputStreamTests extends ESTestCase { + + private static int TEST_ARRAY_SIZE = 1 << 20; + private static byte[] testPlaintextArray; + private static SecretKey secretKey; + + @BeforeClass + static void createSecretKeyAndTestArray() throws Exception { + try { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(256, new 
SecureRandom()); + secretKey = keyGen.generateKey(); + } catch (Exception e) { + throw new RuntimeException(e); + } + testPlaintextArray = new byte[TEST_ARRAY_SIZE]; + Randomness.get().nextBytes(testPlaintextArray); + } + + public void testEmpty() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(2048); + testEncryptPacketWise(0, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSingleByteSize() throws Exception { + testEncryptPacketWise(1, 1, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(1, 2, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(1, 3, new DefaultBufferedReadAllStrategy()); + int packetSize = 4 + Randomness.get().nextInt(2046); + testEncryptPacketWise(1, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeSmallerThanPacketSize() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(2045); + testEncryptPacketWise(packetSize - 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetSize - 2, packetSize, new DefaultBufferedReadAllStrategy()); + int size = 1 + Randomness.get().nextInt(packetSize - 1); + testEncryptPacketWise(size, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeEqualToPacketSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(2048); + testEncryptPacketWise(packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeLargerThanPacketSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(2048); + testEncryptPacketWise(packetSize + 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetSize + 2, packetSize, new DefaultBufferedReadAllStrategy()); + int size = packetSize + 3 + Randomness.get().nextInt(packetSize); + testEncryptPacketWise(size, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeMultipleOfPacketSize() throws Exception { + int packetSize 
= 1 + Randomness.get().nextInt(512); + testEncryptPacketWise(2 * packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(3 * packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + int packetCount = 4 + Randomness.get().nextInt(12); + testEncryptPacketWise(packetCount * packetSize, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testSizeAlmostMultipleOfPacketSize() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(510); + int packetCount = 2 + Randomness.get().nextInt(15); + testEncryptPacketWise(packetCount * packetSize - 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetCount * packetSize - 2, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetCount * packetSize + 1, packetSize, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(packetCount * packetSize + 2, packetSize, new DefaultBufferedReadAllStrategy()); + } + + public void testShortPacketSizes() throws Exception { + int packetCount = 2 + Randomness.get().nextInt(15); + testEncryptPacketWise(2 + Randomness.get().nextInt(15), 1, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(4 + Randomness.get().nextInt(30), 2, new DefaultBufferedReadAllStrategy()); + testEncryptPacketWise(6 + Randomness.get().nextInt(45), 3, new DefaultBufferedReadAllStrategy()); + } + + public void testPacketSizeMultipleOfAESBlockSize() throws Exception { + int packetSize = 1 + Randomness.get().nextInt(8); + testEncryptPacketWise( + 1 + Randomness.get().nextInt(packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES), + packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES, + new DefaultBufferedReadAllStrategy() + ); + testEncryptPacketWise( + packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES + Randomness.get().nextInt(8192), + packetSize * EncryptedRepository.AES_BLOCK_LENGTH_IN_BYTES, + new DefaultBufferedReadAllStrategy() + ); + } + + public void 
testMarkAndResetPacketBoundaryNoMock() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(512); + int size = 4 * packetSize + Randomness.get().nextInt(512); + int plaintextOffset = Randomness.get().nextInt(testPlaintextArray.length - size + 1); + int nonce = Randomness.get().nextInt(); + final byte[] referenceCiphertextArray; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, size), + secretKey, + nonce, + packetSize + ) + ) { + referenceCiphertextArray = readAllBytes(encryptionInputStream); + } + assertThat((long) referenceCiphertextArray.length, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(size, packetSize))); + int encryptedPacketSize = packetSize + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, size), + secretKey, + nonce, + packetSize + ) + ) { + // mark at the beginning + encryptionInputStream.mark(encryptedPacketSize - 1); + byte[] test = readNBytes(encryptionInputStream, 1 + Randomness.get().nextInt(encryptedPacketSize)); + assertSubArray(referenceCiphertextArray, 0, test, 0, test.length); + // reset at the beginning + encryptionInputStream.reset(); + // read packet fragment + test = readNBytes(encryptionInputStream, 1 + Randomness.get().nextInt(encryptedPacketSize)); + assertSubArray(referenceCiphertextArray, 0, test, 0, test.length); + // reset at the beginning + encryptionInputStream.reset(); + // read complete packet + test = readNBytes(encryptionInputStream, encryptedPacketSize); + assertSubArray(referenceCiphertextArray, 0, test, 0, test.length); + // mark at the second packet boundary + encryptionInputStream.mark(Integer.MAX_VALUE); + // read more than one packet + test = readNBytes(encryptionInputStream, encryptedPacketSize + 1 + 
Randomness.get().nextInt(encryptedPacketSize)); + assertSubArray(referenceCiphertextArray, encryptedPacketSize, test, 0, test.length); + // reset at the second packet boundary + encryptionInputStream.reset(); + int middlePacketOffset = Randomness.get().nextInt(encryptedPacketSize); + test = readNBytes(encryptionInputStream, middlePacketOffset); + assertSubArray(referenceCiphertextArray, encryptedPacketSize, test, 0, test.length); + // read up to the third packet boundary + test = readNBytes(encryptionInputStream, encryptedPacketSize - middlePacketOffset); + assertSubArray(referenceCiphertextArray, encryptedPacketSize + middlePacketOffset, test, 0, test.length); + // mark at the third packet boundary + encryptionInputStream.mark(Integer.MAX_VALUE); + test = readAllBytes(encryptionInputStream); + assertSubArray(referenceCiphertextArray, 2 * encryptedPacketSize, test, 0, test.length); + encryptionInputStream.reset(); + test = readNBytes( + encryptionInputStream, + 1 + Randomness.get().nextInt(referenceCiphertextArray.length - 2 * encryptedPacketSize) + ); + assertSubArray(referenceCiphertextArray, 2 * encryptedPacketSize, test, 0, test.length); + } + } + + public void testMarkResetInsidePacketNoMock() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(64); + int encryptedPacketSize = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + packetSize + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + int size = 3 * packetSize + Randomness.get().nextInt(64); + byte[] bytes = new byte[size]; + Randomness.get().nextBytes(bytes); + int nonce = Randomness.get().nextInt(); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream(new TestInputStream(bytes), secretKey, nonce, packetSize); + int offset1 = 1 + Randomness.get().nextInt(encryptedPacketSize - 1); + // read past the first packet + readNBytes(test, encryptedPacketSize + offset1); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 2)); + assertThat(((CountingInputStream) test.currentIn).count, 
Matchers.is((long) offset1)); + assertThat(test.markCounter, Matchers.nullValue()); + int readLimit = 1 + Randomness.get().nextInt(packetSize); + // first mark + test.mark(readLimit); + assertThat(test.markCounter, Matchers.is(test.counter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + // read before packet is complete + readNBytes(test, 1 + Randomness.get().nextInt(encryptedPacketSize - offset1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + // reset + test.reset(); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.counter, Matchers.is(test.markCounter)); + assertThat(test.currentIn, Matchers.is(test.markIn)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1)); + // read before the packet is complete + int offset2 = 1 + Randomness.get().nextInt(encryptedPacketSize - offset1); + readNBytes(test, offset2); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1 + offset2)); + // second mark + readLimit = 1 + Randomness.get().nextInt(packetSize); + test.mark(readLimit); + assertThat(((TestInputStream) test.source).mark, Matchers.is(-1)); + assertThat(test.markCounter, Matchers.is(test.counter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset1 + offset2)); + } + + public void testMarkResetAcrossPacketsNoMock() throws Exception { + int packetSize = 3 + Randomness.get().nextInt(64); + int encryptedPacketSize = EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + packetSize + 
EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + int size = 3 * packetSize + Randomness.get().nextInt(64); + byte[] bytes = new byte[size]; + Randomness.get().nextBytes(bytes); + int nonce = Randomness.get().nextInt(); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream(new TestInputStream(bytes), secretKey, nonce, packetSize); + int readLimit = 2 * size + Randomness.get().nextInt(4096); + // mark at the beginning + test.mark(readLimit); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(test.markIn, Matchers.nullValue()); + // read past the first packet + int offset1 = 1 + Randomness.get().nextInt(encryptedPacketSize); + readNBytes(test, encryptedPacketSize + offset1); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(0)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 2)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset1)); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.markIn, Matchers.nullValue()); + // reset at the beginning + test.reset(); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE)); + assertThat(test.currentIn, Matchers.nullValue()); + assertThat(((TestInputStream) test.source).off, Matchers.is(0)); + // read past the first two packets + int offset2 = 1 + Randomness.get().nextInt(encryptedPacketSize); + readNBytes(test, 2 * encryptedPacketSize + offset2); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(((TestInputStream) test.source).mark, Matchers.is(0)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset2)); + assertThat(test.markCounter, 
Matchers.is(Long.MIN_VALUE)); + assertThat(test.markIn, Matchers.nullValue()); + // mark inside the third packet + test.mark(readLimit); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset2)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset2)); + // read until the end + readAllBytes(test); + assertThat(test.markCounter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(test.counter, Matchers.not(Long.MIN_VALUE + 3)); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(test.markIn, Matchers.not(test.currentIn)); + assertThat(((CountingInputStream) test.markIn).mark, Matchers.is((long) offset2)); + // reset + test.reset(); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + assertThat(test.counter, Matchers.is(Long.MIN_VALUE + 3)); + assertThat(((CountingInputStream) test.currentIn).count, Matchers.is((long) offset2)); + assertThat(test.markIn, Matchers.is(test.currentIn)); + } + + public void testMarkAfterResetNoMock() throws Exception { + int packetSize = 4 + Randomness.get().nextInt(4); + int plainLen = packetSize + 1 + Randomness.get().nextInt(packetSize - 1); + int plaintextOffset = Randomness.get().nextInt(testPlaintextArray.length - plainLen + 1); + int nonce = Randomness.get().nextInt(); + final byte[] referenceCiphertextArray; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, plainLen), + secretKey, + nonce, + packetSize + ) + ) { + referenceCiphertextArray = readAllBytes(encryptionInputStream); + } + int encryptedLen = referenceCiphertextArray.length; + assertThat((long) encryptedLen, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(plainLen, packetSize))); + for (int mark1 = 
0; mark1 < encryptedLen; mark1++) { + for (int offset1 = 0; offset1 < encryptedLen - mark1; offset1++) { + int mark2 = Randomness.get().nextInt(encryptedLen - mark1); + int offset2 = Randomness.get().nextInt(encryptedLen - mark1 - mark2); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, plainLen), + secretKey, + nonce, + packetSize + ); + // read "mark1" bytes + byte[] pre = readNBytes(test, mark1); + for (int i = 0; i < pre.length; i++) { + assertThat(pre[i], Matchers.is(referenceCiphertextArray[i])); + } + // first mark + test.mark(encryptedLen); + // read "offset" bytes + byte[] span1 = readNBytes(test, offset1); + for (int i = 0; i < span1.length; i++) { + assertThat(span1[i], Matchers.is(referenceCiphertextArray[mark1 + i])); + } + // reset back to "mark1" offset + test.reset(); + // read/replay "mark2" bytes + byte[] span2 = readNBytes(test, mark2); + for (int i = 0; i < span2.length; i++) { + assertThat(span2[i], Matchers.is(referenceCiphertextArray[mark1 + i])); + } + // second mark + test.mark(encryptedLen); + byte[] span3 = readNBytes(test, offset2); + for (int i = 0; i < span3.length; i++) { + assertThat(span3[i], Matchers.is(referenceCiphertextArray[mark1 + mark2 + i])); + } + // reset to second mark + test.reset(); + // read rest of bytes + byte[] span4 = readAllBytes(test); + for (int i = 0; i < span4.length; i++) { + assertThat(span4[i], Matchers.is(referenceCiphertextArray[mark1 + mark2 + i])); + } + } + } + } + + public void testMark() throws Exception { + InputStream mockSource = mock(InputStream.class); + when(mockSource.markSupported()).thenAnswer(invocationOnMock -> true); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream( + mockSource, + mock(SecretKey.class), + Randomness.get().nextInt(), + 1 + Randomness.get().nextInt(32) + ); + int readLimit = 1 + Randomness.get().nextInt(4096); + InputStream mockMarkIn = mock(InputStream.class); + 
test.markIn = mockMarkIn; + InputStream mockCurrentIn = mock(InputStream.class); + test.currentIn = mockCurrentIn; + test.counter = Randomness.get().nextLong(); + test.markCounter = Randomness.get().nextLong(); + test.markSourceOnNextPacket = Randomness.get().nextInt(); + // mark + test.mark(readLimit); + verify(mockMarkIn).close(); + assertThat(test.markIn, Matchers.is(mockCurrentIn)); + verify(test.markIn).mark(Mockito.anyInt()); + assertThat(test.currentIn, Matchers.is(mockCurrentIn)); + assertThat(test.markCounter, Matchers.is(test.counter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(readLimit)); + } + + public void testReset() throws Exception { + InputStream mockSource = mock(InputStream.class); + when(mockSource.markSupported()).thenAnswer(invocationOnMock -> true); + EncryptionPacketsInputStream test = new EncryptionPacketsInputStream( + mockSource, + mock(SecretKey.class), + Randomness.get().nextInt(), + 1 + Randomness.get().nextInt(32) + ); + InputStream mockMarkIn = mock(InputStream.class); + test.markIn = mockMarkIn; + InputStream mockCurrentIn = mock(InputStream.class); + test.currentIn = mockCurrentIn; + test.counter = Randomness.get().nextLong(); + test.markCounter = Randomness.get().nextLong(); + // source requires reset as well + test.markSourceOnNextPacket = -1; + // reset + test.reset(); + verify(mockCurrentIn).close(); + assertThat(test.currentIn, Matchers.is(mockMarkIn)); + verify(test.currentIn).reset(); + assertThat(test.markIn, Matchers.is(mockMarkIn)); + assertThat(test.counter, Matchers.is(test.markCounter)); + assertThat(test.markSourceOnNextPacket, Matchers.is(-1)); + verify(mockSource).reset(); + } + + private void testEncryptPacketWise(int size, int packetSize, ReadStrategy readStrategy) throws Exception { + int encryptedPacketSize = packetSize + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES; + int plaintextOffset = Randomness.get().nextInt(testPlaintextArray.length - size + 1); + 
int nonce = Randomness.get().nextInt(); + long counter = EncryptedRepository.PACKET_START_COUNTER; + try ( + InputStream encryptionInputStream = new EncryptionPacketsInputStream( + new ByteArrayInputStream(testPlaintextArray, plaintextOffset, size), + secretKey, + nonce, + packetSize + ) + ) { + byte[] ciphertextArray = readStrategy.readAll(encryptionInputStream); + assertThat((long) ciphertextArray.length, Matchers.is(EncryptionPacketsInputStream.getEncryptionLength(size, packetSize))); + for (int ciphertextOffset = 0; ciphertextOffset < ciphertextArray.length; ciphertextOffset += encryptedPacketSize) { + ByteBuffer ivBuffer = ByteBuffer.wrap(ciphertextArray, ciphertextOffset, EncryptedRepository.GCM_IV_LENGTH_IN_BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + assertThat(ivBuffer.getInt(), Matchers.is(nonce)); + assertThat(ivBuffer.getLong(), Matchers.is(counter++)); + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec( + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES * Byte.SIZE, + Arrays.copyOfRange(ciphertextArray, ciphertextOffset, ciphertextOffset + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES) + ); + Cipher packetCipher = Cipher.getInstance(EncryptedRepository.DATA_ENCRYPTION_SCHEME); + packetCipher.init(Cipher.DECRYPT_MODE, secretKey, gcmParameterSpec); + try ( + InputStream packetDecryptionInputStream = new CipherInputStream( + new ByteArrayInputStream( + ciphertextArray, + ciphertextOffset + EncryptedRepository.GCM_IV_LENGTH_IN_BYTES, + packetSize + EncryptedRepository.GCM_TAG_LENGTH_IN_BYTES + ), + packetCipher + ) + ) { + byte[] decryptedCiphertext = readAllBytes(packetDecryptionInputStream); + int decryptedPacketSize = size <= packetSize ? 
size : packetSize; + assertThat(decryptedCiphertext.length, Matchers.is(decryptedPacketSize)); + assertSubArray(decryptedCiphertext, 0, testPlaintextArray, plaintextOffset, decryptedPacketSize); + size -= decryptedPacketSize; + plaintextOffset += decryptedPacketSize; + } + } + } + } + + private void assertSubArray(byte[] arr1, int offset1, byte[] arr2, int offset2, int length) { + // this is `Objects#checkFromIndexSize(off, len, b.length)` from Java 9 + if ((arr1.length | offset1 | length) < 0 || length > arr1.length - offset1) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, % arr2.length - offset2) { + throw new IndexOutOfBoundsException( + String.format(Locale.ROOT, "Range [%d, %= len) { + return -1; + } + return b[off++] & 0xFF; + } + + @Override + public void close() throws IOException { + closed.set(true); + } + + @Override + public void mark(int readlimit) { + this.mark = off; + } + + @Override + public void reset() { + this.off = this.mark; + } + + @Override + public boolean markSupported() { + return true; + } + + } + + // `InputStream#readAllBytes` from Java 9 + protected static byte[] readAllBytes(InputStream inputStream) throws IOException { + return readNBytes(inputStream, Integer.MAX_VALUE); + } + + // `InputStream#readNBytes` from Java 9 + protected static byte[] readNBytes(InputStream inputStream, int len) throws IOException { + if (len < 0) { + throw new IllegalArgumentException("len < 0"); + } + + List bufs = null; + byte[] result = null; + int total = 0; + int remaining = len; + int n; + do { + byte[] buf = new byte[Math.min(remaining, 8192)]; + int nread = 0; + + // read to EOF which may read more or less than buffer size + while ((n = inputStream.read(buf, nread, Math.min(buf.length - nread, remaining))) > 0) { + nread += n; + remaining -= n; + } + + if (nread > 0) { + if (Integer.MAX_VALUE - 8 - total < nread) { + throw new OutOfMemoryError("Required array size too large"); + } + total += nread; + if (result == 
null) { + result = buf; + } else { + if (bufs == null) { + bufs = new ArrayList<>(); + bufs.add(result); + } + bufs.add(buf); + } + } + // if the last call to read returned -1 or the number of bytes + // requested have been read then break + } while (n >= 0 && remaining > 0); + + if (bufs == null) { + if (result == null) { + return new byte[0]; + } + return result.length == total ? result : Arrays.copyOf(result, total); + } + + result = new byte[total]; + int offset = 0; + remaining = total; + for (byte[] b : bufs) { + int count = Math.min(b.length, remaining); + System.arraycopy(b, 0, result, offset, count); + offset += count; + remaining -= count; + } + + return result; + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java new file mode 100644 index 0000000000000..7e6080ceac476 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/LocalStateEncryptedRepositoryPlugin.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.repositories.encrypted; + +import org.apache.lucene.index.IndexCommit; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus; +import org.elasticsearch.index.store.Store; +import org.elasticsearch.indices.recovery.RecoverySettings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.repositories.IndexId; +import org.elasticsearch.repositories.blobstore.BlobStoreRepository; +import org.elasticsearch.snapshots.SnapshotId; +import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin; + +import java.nio.file.Path; +import java.security.GeneralSecurityException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; + +public final class LocalStateEncryptedRepositoryPlugin extends LocalStateCompositeXPackPlugin { + + final EncryptedRepositoryPlugin encryptedRepositoryPlugin; + + public LocalStateEncryptedRepositoryPlugin(final Settings settings, final Path configPath) { + super(settings, configPath); + final LocalStateEncryptedRepositoryPlugin thisVar = this; + + encryptedRepositoryPlugin = new EncryptedRepositoryPlugin() { + + @Override + protected XPackLicenseState getLicenseState() { + return thisVar.getLicenseState(); + } + + @Override + protected EncryptedRepository createEncryptedRepository( + RepositoryMetadata metadata, + 
NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repoPassword + ) throws GeneralSecurityException { + return new TestEncryptedRepository( + metadata, + registry, + clusterService, + bigArrays, + recoverySettings, + delegatedRepository, + licenseStateSupplier, + repoPassword + ); + } + }; + plugins.add(encryptedRepositoryPlugin); + } + + static class TestEncryptedRepository extends EncryptedRepository { + private final Lock snapshotShardLock = new ReentrantLock(); + private final Condition snapshotShardCondition = snapshotShardLock.newCondition(); + private final AtomicBoolean snapshotShardBlock = new AtomicBoolean(false); + + TestEncryptedRepository( + RepositoryMetadata metadata, + NamedXContentRegistry registry, + ClusterService clusterService, + BigArrays bigArrays, + RecoverySettings recoverySettings, + BlobStoreRepository delegatedRepository, + Supplier licenseStateSupplier, + SecureString repoPassword + ) throws GeneralSecurityException { + super(metadata, registry, clusterService, bigArrays, recoverySettings, delegatedRepository, licenseStateSupplier, repoPassword); + } + + @Override + public void snapshotShard( + Store store, + MapperService mapperService, + SnapshotId snapshotId, + IndexId indexId, + IndexCommit snapshotIndexCommit, + String shardStateIdentifier, + IndexShardSnapshotStatus snapshotStatus, + Version repositoryMetaVersion, + Map userMetadata, + ActionListener listener + ) { + snapshotShardLock.lock(); + try { + while (snapshotShardBlock.get()) { + snapshotShardCondition.await(); + } + super.snapshotShard( + store, + mapperService, + snapshotId, + indexId, + snapshotIndexCommit, + shardStateIdentifier, + snapshotStatus, + repositoryMetaVersion, + userMetadata, + listener + ); + } catch (InterruptedException e) { + listener.onFailure(e); + } finally { + 
snapshotShardLock.unlock(); + } + } + + void blockSnapshotShard() { + snapshotShardLock.lock(); + try { + snapshotShardBlock.set(true); + snapshotShardCondition.signalAll(); + } finally { + snapshotShardLock.unlock(); + } + } + + void unblockSnapshotShard() { + snapshotShardLock.lock(); + try { + snapshotShardBlock.set(false); + snapshotShardCondition.signalAll(); + } finally { + snapshotShardLock.unlock(); + } + } + } + +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java new file mode 100644 index 0000000000000..c448c54e7fef0 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/PrefixInputStreamTests.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.repositories.encrypted.EncryptionPacketsInputStreamTests.readAllBytes; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class PrefixInputStreamTests extends ESTestCase { + + public void testZeroLength() throws Exception { + Tuple mockTuple = getMockBoundedInputStream(0); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), 1 + Randomness.get().nextInt(32), randomBoolean()); + assertThat(test.available(), Matchers.is(0)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.skip(1 + Randomness.get().nextInt(32)), Matchers.is(0L)); + } + + public void testClose() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = Randomness.get().nextInt(boundedLength); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + test.close(); + int byteCountBefore = mockTuple.v1().get(); + IOException e = expectThrows(IOException.class, () -> { test.read(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { + byte[] b = new byte[1 + Randomness.get().nextInt(32)]; + test.read(b, 0, 1 + Randomness.get().nextInt(b.length)); + }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () 
-> { test.skip(1 + Randomness.get().nextInt(32)); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + e = expectThrows(IOException.class, () -> { test.available(); }); + assertThat(e.getMessage(), Matchers.is("Stream has been closed")); + int byteCountAfter = mockTuple.v1().get(); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(0)); + // test closeSource parameter + AtomicBoolean isClosed = new AtomicBoolean(false); + InputStream mockIn = mock(InputStream.class); + doAnswer(new Answer() { + public Void answer(InvocationOnMock invocation) { + isClosed.set(true); + return null; + } + }).when(mockIn).close(); + new PrefixInputStream(mockIn, 1 + Randomness.get().nextInt(32), true).close(); + assertThat(isClosed.get(), Matchers.is(true)); + isClosed.set(false); + new PrefixInputStream(mockIn, 1 + Randomness.get().nextInt(32), false).close(); + assertThat(isClosed.get(), Matchers.is(false)); + } + + public void testAvailable() throws Exception { + AtomicInteger available = new AtomicInteger(0); + int boundedLength = 1 + Randomness.get().nextInt(256); + InputStream mockIn = mock(InputStream.class); + when(mockIn.available()).thenAnswer(invocationOnMock -> { return available.get(); }); + PrefixInputStream test = new PrefixInputStream(mockIn, boundedLength, randomBoolean()); + assertThat(test.available(), Matchers.is(0)); + available.set(Randomness.get().nextInt(boundedLength)); + assertThat(test.available(), Matchers.is(available.get())); + available.set(boundedLength + 1 + Randomness.get().nextInt(boundedLength)); + assertThat(test.available(), Matchers.is(boundedLength)); + } + + public void testReadPrefixLength() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = Randomness.get().nextInt(boundedLength); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = 
mockTuple.v1().get(); + byte[] b = readAllBytes(test); + int byteCountAfter = mockTuple.v1().get(); + assertThat(b.length, Matchers.is(prefixLength)); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(prefixLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.not(-1)); + } + + public void testSkipPrefixLength() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = Randomness.get().nextInt(boundedLength); + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + skipNBytes(test, prefixLength); + int byteCountAfter = mockTuple.v1().get(); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(prefixLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.not(-1)); + } + + public void testReadShorterWrapped() throws Exception { + int boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + int prefixLength = boundedLength; + if (randomBoolean()) { + prefixLength += 1 + Randomness.get().nextInt(boundedLength); + } + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + byte[] b = readAllBytes(test); + int byteCountAfter = mockTuple.v1().get(); + assertThat(b.length, Matchers.is(boundedLength)); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(boundedLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.is(-1)); + assertThat(mockTuple.v2().available(), Matchers.is(0)); + } + + public void testSkipShorterWrapped() throws Exception { + int 
boundedLength = 1 + Randomness.get().nextInt(256); + Tuple mockTuple = getMockBoundedInputStream(boundedLength); + final int prefixLength; + if (randomBoolean()) { + prefixLength = boundedLength + 1 + Randomness.get().nextInt(boundedLength); + } else { + prefixLength = boundedLength; + } + PrefixInputStream test = new PrefixInputStream(mockTuple.v2(), prefixLength, randomBoolean()); + int byteCountBefore = mockTuple.v1().get(); + if (prefixLength == boundedLength) { + skipNBytes(test, prefixLength); + } else { + expectThrows(EOFException.class, () -> { skipNBytes(test, prefixLength); }); + } + int byteCountAfter = mockTuple.v1().get(); + assertThat(byteCountBefore - byteCountAfter, Matchers.is(boundedLength)); + assertThat(test.read(), Matchers.is(-1)); + assertThat(test.available(), Matchers.is(0)); + assertThat(mockTuple.v2().read(), Matchers.is(-1)); + assertThat(mockTuple.v2().available(), Matchers.is(0)); + } + + private Tuple getMockBoundedInputStream(int bound) throws IOException { + InputStream mockSource = mock(InputStream.class); + AtomicInteger bytesRemaining = new AtomicInteger(bound); + when(mockSource.read(org.mockito.Matchers.any(), org.mockito.Matchers.anyInt(), org.mockito.Matchers.anyInt())).thenAnswer( + invocationOnMock -> { + final byte[] b = (byte[]) invocationOnMock.getArguments()[0]; + final int off = (int) invocationOnMock.getArguments()[1]; + final int len = (int) invocationOnMock.getArguments()[2]; + if (len == 0) { + return 0; + } else { + if (bytesRemaining.get() <= 0) { + return -1; + } + int bytesCount = 1 + Randomness.get().nextInt(Math.min(len, bytesRemaining.get())); + bytesRemaining.addAndGet(-bytesCount); + return bytesCount; + } + } + ); + when(mockSource.read()).thenAnswer(invocationOnMock -> { + if (bytesRemaining.get() <= 0) { + return -1; + } + bytesRemaining.decrementAndGet(); + return Randomness.get().nextInt(256); + }); + when(mockSource.skip(org.mockito.Matchers.anyLong())).thenAnswer(invocationOnMock -> { + final long n 
= (long) invocationOnMock.getArguments()[0]; + if (n <= 0 || bytesRemaining.get() <= 0) { + return 0; + } + int bytesSkipped = 1 + Randomness.get().nextInt(Math.min(bytesRemaining.get(), Math.toIntExact(n))); + bytesRemaining.addAndGet(-bytesSkipped); + return bytesSkipped; + }); + when(mockSource.available()).thenAnswer(invocationOnMock -> { + if (bytesRemaining.get() <= 0) { + return 0; + } + return 1 + Randomness.get().nextInt(bytesRemaining.get()); + }); + when(mockSource.markSupported()).thenReturn(false); + return new Tuple<>(bytesRemaining, mockSource); + } + + private static void skipNBytes(InputStream in, long n) throws IOException { + if (n > 0) { + long ns = in.skip(n); + if (ns >= 0 && ns < n) { // skipped too few bytes + // adjust number to skip + n -= ns; + // read until requested number skipped or EOS reached + while (n > 0 && in.read() != -1) { + n--; + } + // if not enough skipped, then EOFE + if (n != 0) { + throw new EOFException(); + } + } else if (ns != n) { // skipped negative or too many bytes + throw new IOException("Unable to skip exactly"); + } + } + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java new file mode 100644 index 0000000000000..034cc41a84888 --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/java/org/elasticsearch/repositories/encrypted/SingleUseKeyTests.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.repositories.encrypted; + +import org.elasticsearch.common.CheckedSupplier; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.Before; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.contains; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SingleUseKeyTests extends ESTestCase { + + byte[] testKeyPlaintext; + SecretKey testKey; + BytesReference testKeyId; + + @Before + public void setUpMocks() { + testKeyPlaintext = randomByteArrayOfLength(32); + testKey = new SecretKeySpec(testKeyPlaintext, "AES"); + testKeyId = new BytesArray(randomAlphaOfLengthBetween(2, 32)); + } + + public void testNewKeySupplier() throws Exception { + CheckedSupplier singleUseKeySupplier = SingleUseKey.createSingleUseKeySupplier( + () -> new Tuple<>(testKeyId, testKey) + ); + SingleUseKey generatedSingleUseKey = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey.getKeyId(), equalTo(testKeyId)); + assertThat(generatedSingleUseKey.getNonce(), equalTo(SingleUseKey.MIN_NONCE)); + assertThat(generatedSingleUseKey.getKey().getEncoded(), equalTo(testKeyPlaintext)); + } + + public void testNonceIncrement() throws Exception { + int nonce = 
randomIntBetween(SingleUseKey.MIN_NONCE, SingleUseKey.MAX_NONCE - 2); + SingleUseKey singleUseKey = new SingleUseKey(testKeyId, testKey, nonce); + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + @SuppressWarnings("unchecked") + CheckedSupplier, IOException> keyGenerator = mock(CheckedSupplier.class); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + keyGenerator, + keyCurrentlyInUse + ); + SingleUseKey generatedSingleUseKey = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey.getKeyId(), equalTo(testKeyId)); + assertThat(generatedSingleUseKey.getNonce(), equalTo(nonce)); + assertThat(generatedSingleUseKey.getKey().getEncoded(), equalTo(testKeyPlaintext)); + SingleUseKey generatedSingleUseKey2 = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey2.getKeyId(), equalTo(testKeyId)); + assertThat(generatedSingleUseKey2.getNonce(), equalTo(nonce + 1)); + assertThat(generatedSingleUseKey2.getKey().getEncoded(), equalTo(testKeyPlaintext)); + verifyZeroInteractions(keyGenerator); + } + + public void testConcurrentWrapAround() throws Exception { + int nThreads = 3; + TestThreadPool testThreadPool = new TestThreadPool( + "SingleUserKeyTests#testConcurrentWrapAround", + Settings.builder() + .put("thread_pool." + ThreadPool.Names.GENERIC + ".size", nThreads) + .put("thread_pool." 
+ ThreadPool.Names.GENERIC + ".queue_size", 1) + .build() + ); + int nonce = SingleUseKey.MAX_NONCE; + SingleUseKey singleUseKey = new SingleUseKey(null, null, nonce); + + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + @SuppressWarnings("unchecked") + CheckedSupplier, IOException> keyGenerator = mock(CheckedSupplier.class); + when(keyGenerator.get()).thenReturn(new Tuple<>(testKeyId, testKey)); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + keyGenerator, + keyCurrentlyInUse + ); + List generatedKeys = new ArrayList<>(nThreads); + for (int i = 0; i < nThreads; i++) { + generatedKeys.add(null); + } + for (int i = 0; i < nThreads; i++) { + final int resultIdx = i; + testThreadPool.generic().execute(() -> { + try { + generatedKeys.set(resultIdx, singleUseKeySupplier.get()); + } catch (IOException e) { + fail(); + } + }); + } + terminate(testThreadPool); + verify(keyGenerator, times(1)).get(); + assertThat(keyCurrentlyInUse.get().getNonce(), equalTo(SingleUseKey.MIN_NONCE + nThreads)); + assertThat(generatedKeys.stream().map(suk -> suk.getKey()).collect(Collectors.toSet()).size(), equalTo(1)); + assertThat( + generatedKeys.stream().map(suk -> suk.getKey().getEncoded()).collect(Collectors.toSet()).iterator().next(), + equalTo(testKeyPlaintext) + ); + assertThat(generatedKeys.stream().map(suk -> suk.getKeyId()).collect(Collectors.toSet()).iterator().next(), equalTo(testKeyId)); + assertThat(generatedKeys.stream().map(suk -> suk.getNonce()).collect(Collectors.toSet()).size(), equalTo(nThreads)); + assertThat( + generatedKeys.stream().map(suk -> suk.getNonce()).collect(Collectors.toSet()), + contains(SingleUseKey.MIN_NONCE, SingleUseKey.MIN_NONCE + 1, SingleUseKey.MIN_NONCE + 2) + ); + } + + public void testNonceWrapAround() throws Exception { + int nonce = SingleUseKey.MAX_NONCE; + SingleUseKey singleUseKey = new SingleUseKey(testKeyId, testKey, nonce); + AtomicReference keyCurrentlyInUse = new 
AtomicReference<>(singleUseKey); + byte[] newTestKeyPlaintext = randomByteArrayOfLength(32); + SecretKey newTestKey = new SecretKeySpec(newTestKeyPlaintext, "AES"); + BytesReference newTestKeyId = new BytesArray(randomAlphaOfLengthBetween(2, 32)); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + () -> new Tuple<>(newTestKeyId, newTestKey), + keyCurrentlyInUse + ); + SingleUseKey generatedSingleUseKey = singleUseKeySupplier.get(); + assertThat(generatedSingleUseKey.getKeyId(), equalTo(newTestKeyId)); + assertThat(generatedSingleUseKey.getNonce(), equalTo(SingleUseKey.MIN_NONCE)); + assertThat(generatedSingleUseKey.getKey().getEncoded(), equalTo(newTestKeyPlaintext)); + } + + public void testGeneratorException() { + int nonce = SingleUseKey.MAX_NONCE; + SingleUseKey singleUseKey = new SingleUseKey(null, null, nonce); + AtomicReference keyCurrentlyInUse = new AtomicReference<>(singleUseKey); + CheckedSupplier singleUseKeySupplier = SingleUseKey.internalSingleUseKeySupplier( + () -> { throw new IOException("expected exception"); }, + keyCurrentlyInUse + ); + expectThrows(IOException.class, () -> singleUseKeySupplier.get()); + } +} diff --git a/x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml b/x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml new file mode 100644 index 0000000000000..858ba3e21e3ae --- /dev/null +++ b/x-pack/plugin/repository-encrypted/src/test/resources/rest-api-spec/test/repository_encrypted/10_basic.yml @@ -0,0 +1,16 @@ +# Integration tests for repository-encrypted +# +"Plugin repository-encrypted is loaded": + - skip: + reason: "contains is a newly added assertion" + features: contains + - do: + cluster.state: {} + + # Get master node id + - set: { master_node: master } + + - do: + nodes.info: {} + + - contains: { nodes.$master.plugins: { name: repository-encrypted } }