From dc16687c876ef0632ab14e30631b72bb4fa70088 Mon Sep 17 00:00:00 2001 From: Akagiwyf <11912913@mail.sustech.edu.cn> Date: Fri, 8 Apr 2022 19:20:07 +0800 Subject: [PATCH 1/6] [basicdataset] Add PennTreebank dataset --- .../ai/djl/basicdataset/nlp/PennTreebank.java | 193 ++++++++++++++++++ .../ai/djl/basicdataset/PennTreebankTest.java | 64 ++++++ .../basicdataset/penntreebank/metadata.json | 40 ++++ 3 files changed, 297 insertions(+) create mode 100644 basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java create mode 100644 basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java create mode 100644 basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java new file mode 100644 index 00000000000..4850fdabdac --- /dev/null +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java @@ -0,0 +1,193 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset.nlp; + +import ai.djl.Application; +import ai.djl.basicdataset.BasicDatasets; +import ai.djl.basicdataset.RawDataset; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; +import ai.djl.repository.Repository; +import ai.djl.training.dataset.Batch; +import ai.djl.training.dataset.Dataset; +import ai.djl.translate.TranslateException; +import ai.djl.util.Progress; +import java.io.IOException; +import java.nio.file.Path; + +/** + * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal + * (WSJ) collection of 98,732 stories for syntactic annotation. + */ +public class PennTreebank implements RawDataset { + + private static final String VERSION = "1.0"; + private static final String ARTIFACT_ID = "penntreebank"; + + private Dataset.Usage usage; + private Path root; + + private MRL mrl; + private boolean prepared; + + PennTreebank(Builder builder) { + this.usage = builder.usage; + mrl = builder.getMrl(); + } + + /** + * Creates a builder to build a {@link PennTreebank}. + * + * @return a new {@link PennTreebank.Builder} object + */ + public static Builder builder() { + return new Builder(); + } + /** + * Fetches an iterator that can iterate through the {@link Dataset}. This method is not + * implemented for the PennTreebank dataset because the PennTreebank dataset is not suitable for + * iteration. If the method is called, it will directly return {@code null}. + * + * @param manager the dataset to iterate through + * @return an {@link Iterable} of {@link Batch} that contains batches of data from the dataset + */ + @Override + public Iterable getData(NDManager manager) throws IOException, TranslateException { + return null; + } + + /** + * Get data from the PennTreebank dataset. This method will directly return the path of required + * dataset. + * + * @return a {@link Path} object locating the PennTreebank dataset file + */ + @Override + public Path getData() throws IOException { + prepare(null); + return root; + } + + /** + * Prepares the dataset for use with tracked progress. + * + * @param progress the progress tracker + * @throws IOException for various exceptions depending on the dataset + */ + @Override + public void prepare(Progress progress) throws IOException { + if (prepared) { + return; + } + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Artifact.Item item; + + switch (usage) { + case TRAIN: + item = artifact.getFiles().get("train"); + break; + case TEST: + item = artifact.getFiles().get("test"); + break; + case VALIDATION: + item = artifact.getFiles().get("valid"); + break; + default: + throw new UnsupportedOperationException("Unsupported usage type."); + } + root = mrl.getRepository().getFile(item, "").toAbsolutePath(); + prepared = true; + } + + /** A builder to construct a {@link PennTreebank} . */ + public static final class Builder { + + Repository repository; + String groupId; + String artifactId; + Dataset.Usage usage; + + /** Constructs a new builder. */ + public Builder() { + repository = BasicDatasets.REPOSITORY; + groupId = BasicDatasets.GROUP_ID; + artifactId = ARTIFACT_ID; + usage = Dataset.Usage.TRAIN; + } + + /** + * Sets the optional repository for the dataset. + * + * @param repository the new repository + * @return this builder + */ + public Builder optRepository(Repository repository) { + this.repository = repository; + return this; + } + + /** + * Sets optional groupId. + * + * @param groupId the groupId + * @return this builder + */ + public Builder optGroupId(String groupId) { + this.groupId = groupId; + return this; + } + + /** + * Sets the optional artifactId. + * + * @param artifactId the artifactId + * @return this builder + */ + public Builder optArtifactId(String artifactId) { + if (artifactId.contains(":")) { + String[] tokens = artifactId.split(":"); + groupId = tokens[0]; + this.artifactId = tokens[1]; + } else { + this.artifactId = artifactId; + } + return this; + } + + /** + * Sets the optional usage for the dataset. + * + * @param usage the usage + * @return this builder + */ + public Builder optUsage(Dataset.Usage usage) { + this.usage = usage; + return this; + } + + /** + * Builds a new {@link PennTreebank} object. + * + * @return the new {@link PennTreebank} object + */ + public PennTreebank build() { + return new PennTreebank(this); + } + + MRL getMrl() { + return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION); + } + } +} diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java new file mode 100644 index 00000000000..b3d20ad4d2e --- /dev/null +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset; + +import ai.djl.basicdataset.nlp.PennTreebank; +import ai.djl.repository.Repository; +import ai.djl.training.dataset.Dataset; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class PennTreebankTest { + + @Test + public void testPennTreebankTrainLocal() throws IOException { + Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); + PennTreebank trainingSet = + PennTreebank.builder() + .optRepository(repository) + .optUsage(Dataset.Usage.TRAIN) + .build(); + Path path = trainingSet.getData(); + Assert.assertTrue(Files.isRegularFile(path)); + Assert.assertEquals(path.getFileName().toString(), "ptb.train.txt"); + } + + @Test + public void testPennTreebankTestLocal() throws IOException { + Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); + PennTreebank trainingSet = + PennTreebank.builder() + .optRepository(repository) + .optUsage(Dataset.Usage.TEST) + .build(); + Path path = trainingSet.getData(); + Assert.assertTrue(Files.isRegularFile(path)); + Assert.assertEquals(path.getFileName().toString(), "ptb.test.txt"); + } + + @Test + public void testPennTreebankValidationLocal() throws IOException { + Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); + PennTreebank trainingSet = + PennTreebank.builder() + .optRepository(repository) + .optUsage(Dataset.Usage.VALIDATION) + .build(); + Path path = trainingSet.getData(); + Assert.assertTrue(Files.isRegularFile(path)); + Assert.assertEquals(path.getFileName().toString(), "ptb.valid.txt"); + } +} diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json new file mode 100644 index 00000000000..9bb2f05cbe1 --- /dev/null +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json @@ -0,0 +1,40 @@ +{ + "metadataVersion": "0.2", + "resourceType": "dataset", + "application": "nlp", + "groupId": "ai.djl.basicdataset", + "artifactId": "penntreebank", + "name": "penntreebank", + "description": "The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal (WSJ) collection of 98,732 stories for syntactic annotation.", + "website": "https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/", + "licenses": { + "license": { + "name": "LDC User Agreement for Non-Members", + "url": "https://catalog.ldc.upenn.edu/license/ldc-non-members-agreement.pdf" + } + }, + "artifacts": [ + { + "version": "1.0", + "snapshot": false, + "name": "penntreebank", + "files": { + "train":{ + "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt", + "sha1Hash": "f9ffb014fa33bd5730e5029697ad245184f3a678", + "size": 5101618 + }, + "test":{ + "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt", + "sha1Hash": "5c15c548b42d80bce9332b788514e6635fb0226e", + "size": 449945 + }, + "valid":{ + "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt", + "sha1Hash": "d9f5fed6afa5e1b82cd1e3e5f5040f6852940228", + "size": 399782 + } + } + } + ] +} \ No newline at end of file From 3990dac3d742c3fd1db1021112f5fbafd4614945 Mon Sep 17 00:00:00 2001 From: Akagiwyf <11912913@mail.sustech.edu.cn> Date: Thu, 21 Apr 2022 21:27:08 +0800 Subject: [PATCH 2/6] make PennTreebankText implement TextDataset and change the name in metadata --- .../ai/djl/basicdataset/nlp/PennTreebank.java | 193 ------------------ .../basicdataset/nlp/PennTreebankText.java | 145 +++++++++++++ .../ai/djl/basicdataset/PennTreebankTest.java | 64 ------ .../basicdataset/PennTreebankTextTest.java | 127 ++++++++++++ .../metadata.json | 8 +- 5 files changed, 276 insertions(+), 261 deletions(-) delete mode 100644 basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java create mode 100644 basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java delete mode 100644 basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java create mode 100644 basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java rename basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/{penntreebank => penntreebank-unlabeled-processed}/metadata.json (85%) diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java deleted file mode 100644 index 4850fdabdac..00000000000 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebank.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.basicdataset.nlp; - -import ai.djl.Application; -import ai.djl.basicdataset.BasicDatasets; -import ai.djl.basicdataset.RawDataset; -import ai.djl.ndarray.NDManager; -import ai.djl.repository.Artifact; -import ai.djl.repository.MRL; -import ai.djl.repository.Repository; -import ai.djl.training.dataset.Batch; -import ai.djl.training.dataset.Dataset; -import ai.djl.translate.TranslateException; -import ai.djl.util.Progress; -import java.io.IOException; -import java.nio.file.Path; - -/** - * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal - * (WSJ) collection of 98,732 stories for syntactic annotation. - */ -public class PennTreebank implements RawDataset { - - private static final String VERSION = "1.0"; - private static final String ARTIFACT_ID = "penntreebank"; - - private Dataset.Usage usage; - private Path root; - - private MRL mrl; - private boolean prepared; - - PennTreebank(Builder builder) { - this.usage = builder.usage; - mrl = builder.getMrl(); - } - - /** - * Creates a builder to build a {@link PennTreebank}. - * - * @return a new {@link PennTreebank.Builder} object - */ - public static Builder builder() { - return new Builder(); - } - /** - * Fetches an iterator that can iterate through the {@link Dataset}. This method is not - * implemented for the PennTreebank dataset because the PennTreebank dataset is not suitable for - * iteration. If the method is called, it will directly return {@code null}. - * - * @param manager the dataset to iterate through - * @return an {@link Iterable} of {@link Batch} that contains batches of data from the dataset - */ - @Override - public Iterable getData(NDManager manager) throws IOException, TranslateException { - return null; - } - - /** - * Get data from the PennTreebank dataset. This method will directly return the path of required - * dataset. - * - * @return a {@link Path} object locating the PennTreebank dataset file - */ - @Override - public Path getData() throws IOException { - prepare(null); - return root; - } - - /** - * Prepares the dataset for use with tracked progress. - * - * @param progress the progress tracker - * @throws IOException for various exceptions depending on the dataset - */ - @Override - public void prepare(Progress progress) throws IOException { - if (prepared) { - return; - } - Artifact artifact = mrl.getDefaultArtifact(); - mrl.prepare(artifact, progress); - Artifact.Item item; - - switch (usage) { - case TRAIN: - item = artifact.getFiles().get("train"); - break; - case TEST: - item = artifact.getFiles().get("test"); - break; - case VALIDATION: - item = artifact.getFiles().get("valid"); - break; - default: - throw new UnsupportedOperationException("Unsupported usage type."); - } - root = mrl.getRepository().getFile(item, "").toAbsolutePath(); - prepared = true; - } - - /** A builder to construct a {@link PennTreebank} . */ - public static final class Builder { - - Repository repository; - String groupId; - String artifactId; - Dataset.Usage usage; - - /** Constructs a new builder. */ - public Builder() { - repository = BasicDatasets.REPOSITORY; - groupId = BasicDatasets.GROUP_ID; - artifactId = ARTIFACT_ID; - usage = Dataset.Usage.TRAIN; - } - - /** - * Sets the optional repository for the dataset. - * - * @param repository the new repository - * @return this builder - */ - public Builder optRepository(Repository repository) { - this.repository = repository; - return this; - } - - /** - * Sets optional groupId. - * - * @param groupId the groupId - * @return this builder - */ - public Builder optGroupId(String groupId) { - this.groupId = groupId; - return this; - } - - /** - * Sets the optional artifactId. - * - * @param artifactId the artifactId - * @return this builder - */ - public Builder optArtifactId(String artifactId) { - if (artifactId.contains(":")) { - String[] tokens = artifactId.split(":"); - groupId = tokens[0]; - this.artifactId = tokens[1]; - } else { - this.artifactId = artifactId; - } - return this; - } - - /** - * Sets the optional usage for the dataset. - * - * @param usage the usage - * @return this builder - */ - public Builder optUsage(Dataset.Usage usage) { - this.usage = usage; - return this; - } - - /** - * Builds a new {@link PennTreebank} object. - * - * @return the new {@link PennTreebank} object - */ - public PennTreebank build() { - return new PennTreebank(this); - } - - MRL getMrl() { - return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION); - } - } -} diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java new file mode 100644 index 00000000000..e62b81bbeb0 --- /dev/null +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java @@ -0,0 +1,145 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset.nlp; +import ai.djl.Application; +import ai.djl.basicdataset.BasicDatasets; +import ai.djl.modality.nlp.embedding.EmbeddingException; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.Artifact; +import ai.djl.repository.MRL; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.dataset.Record; +import ai.djl.util.Progress; +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +/** + * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal + * (WSJ) collection of 98,732 stories for syntactic annotation. + */ +public class PennTreebankText extends TextDataset { + + private static final String VERSION = "1.0"; + private static final String ARTIFACT_ID = "penntreebank-unlabeled-processed"; + + /** + * Creates a new instance of {@link PennTreebankText} with the given necessary + * configurations. + * + * @param builder a builder with the necessary configurations + */ + PennTreebankText(Builder builder) { + super(builder); + this.usage = builder.usage; + mrl = builder.getMrl(); + } + + /** + * Creates a builder to build a {@link PennTreebankText}. + * + * @return a new {@link PennTreebankText.Builder} object + */ + public static Builder builder() { + return new Builder(); + } + + /** {@inheritDoc} */ + @Override + public Record get(NDManager manager, long index) throws IOException { + NDList data = new NDList(); + NDList labels = null; + data.add(sourceTextData.getEmbedding(manager, index)); + return new Record(data, labels); + } + + /** {@inheritDoc} */ + @Override + protected long availableSize() { + return sourceTextData.getSize(); + } + + /** + * Prepares the dataset for use with tracked progress. + * + * @param progress the progress tracker + * @throws IOException for various exceptions depending on the dataset + */ + @Override + public void prepare(Progress progress) throws IOException, EmbeddingException { + if (prepared) { + return; + } + Artifact artifact = mrl.getDefaultArtifact(); + mrl.prepare(artifact, progress); + Artifact.Item item; + switch (usage) { + case TRAIN: + item = artifact.getFiles().get("train"); + break; + case TEST: + item = artifact.getFiles().get("test"); + break; + case VALIDATION: + item = artifact.getFiles().get("valid"); + break; + default: + throw new UnsupportedOperationException("Unsupported usage type."); + } + Path path = mrl.getRepository().getFile(item, "").toAbsolutePath(); + List lineArray = new ArrayList<>(); + try (BufferedReader reader = Files.newBufferedReader(path)) { + String row; + while ((row = reader.readLine()) != null) { + lineArray.add(row); + } + } + preprocess(lineArray,true); + prepared = true; + } + + /** A builder to construct a {@link PennTreebankText} . */ + public static class Builder extends TextDataset.Builder{ + + /** Constructs a new builder. */ + public Builder() { + repository = BasicDatasets.REPOSITORY; + groupId = BasicDatasets.GROUP_ID; + artifactId = ARTIFACT_ID; + usage = Dataset.Usage.TRAIN; + } + + /** + * Builds a new {@link PennTreebankText} object. + * + * @return the new {@link PennTreebankText} object + */ + public PennTreebankText build() { + return new PennTreebankText(this); + } + + MRL getMrl() { + return repository.dataset(Application.NLP.ANY, groupId, artifactId, VERSION); + } + + /** {@inheritDoc} */ + @Override + protected Builder self() { + return this; + } + } +} diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java deleted file mode 100644 index b3d20ad4d2e..00000000000 --- a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTest.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance - * with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES - * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions - * and limitations under the License. - */ -package ai.djl.basicdataset; - -import ai.djl.basicdataset.nlp.PennTreebank; -import ai.djl.repository.Repository; -import ai.djl.training.dataset.Dataset; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class PennTreebankTest { - - @Test - public void testPennTreebankTrainLocal() throws IOException { - Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); - PennTreebank trainingSet = - PennTreebank.builder() - .optRepository(repository) - .optUsage(Dataset.Usage.TRAIN) - .build(); - Path path = trainingSet.getData(); - Assert.assertTrue(Files.isRegularFile(path)); - Assert.assertEquals(path.getFileName().toString(), "ptb.train.txt"); - } - - @Test - public void testPennTreebankTestLocal() throws IOException { - Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); - PennTreebank trainingSet = - PennTreebank.builder() - .optRepository(repository) - .optUsage(Dataset.Usage.TEST) - .build(); - Path path = trainingSet.getData(); - Assert.assertTrue(Files.isRegularFile(path)); - Assert.assertEquals(path.getFileName().toString(), "ptb.test.txt"); - } - - @Test - public void testPennTreebankValidationLocal() throws IOException { - Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); - PennTreebank trainingSet = - PennTreebank.builder() - .optRepository(repository) - .optUsage(Dataset.Usage.VALIDATION) - .build(); - Path path = trainingSet.getData(); - Assert.assertTrue(Files.isRegularFile(path)); - Assert.assertEquals(path.getFileName().toString(), "ptb.valid.txt"); - } -} diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java new file mode 100644 index 00000000000..9c6c81be132 --- /dev/null +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.basicdataset; + +import ai.djl.basicdataset.nlp.PennTreebankText; +import ai.djl.basicdataset.utils.TextData.Configuration; +import ai.djl.ndarray.NDManager; +import ai.djl.repository.Repository; +import ai.djl.training.dataset.Dataset; +import ai.djl.training.dataset.Record; +import ai.djl.translate.TranslateException; + +import java.io.IOException; + + +import org.testng.Assert; +import org.testng.annotations.Test; + +public class PennTreebankTextTest { + + private static final int EMBEDDING_SIZE = 15; + + + @Test + public void testPennTreebankTextTrainLocal() throws IOException, TranslateException { + Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); + try (NDManager manager = NDManager.newBaseManager()) { + PennTreebankText dataset = + PennTreebankText.builder() + .setSourceConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setTargetConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setSampling(32, true) + .optLimit(100) + .optRepository(repository) + .optUsage(Dataset.Usage.TRAIN) + .build(); + + dataset.prepare(); + Record record = dataset.get(manager, 0); + Assert.assertEquals(record.getData().get(0).getShape().dimension(),2); + Assert.assertNull(record.getLabels()); + } + } + + @Test + public void testPennTreebankTextTestLocal() throws IOException, TranslateException { + Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); + try (NDManager manager = NDManager.newBaseManager()) { + PennTreebankText dataset = + PennTreebankText.builder() + .setSourceConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setTargetConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setSampling(32, true) + .optLimit(100) + .optRepository(repository) + .optUsage(Dataset.Usage.TEST) + .build(); + + dataset.prepare(); + Record record = dataset.get(manager, 0); + Assert.assertEquals(record.getData().get(0).getShape().dimension(),2); + Assert.assertNull(record.getLabels()); + } + } + + @Test + public void testPennTreebankTextValidationLocal() throws IOException, TranslateException { + Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); + try (NDManager manager = NDManager.newBaseManager()) { + PennTreebankText dataset = + PennTreebankText.builder() + .setSourceConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setTargetConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setSampling(32, true) + .optLimit(100) + .optRepository(repository) + .optUsage(Dataset.Usage.VALIDATION) + .build(); + + dataset.prepare(); + Record record = dataset.get(manager, 0); + Assert.assertEquals(record.getData().get(0).getShape().dimension(),2); + Assert.assertNull(record.getLabels()); + } + } + +} diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json similarity index 85% rename from basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json rename to basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json index 9bb2f05cbe1..6ed361d659b 100644 --- a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank/metadata.json +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json @@ -3,10 +3,10 @@ "resourceType": "dataset", "application": "nlp", "groupId": "ai.djl.basicdataset", - "artifactId": "penntreebank", - "name": "penntreebank", + "artifactId": "penntreebank-unlabeled-processed", + "name": "penntreebank-unlabeled-processed", "description": "The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal (WSJ) collection of 98,732 stories for syntactic annotation.", - "website": "https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/", + "website": "https://github.com/wojzaremba/lstm/tree/master/data", "licenses": { "license": { "name": "LDC User Agreement for Non-Members", @@ -17,7 +17,7 @@ { "version": "1.0", "snapshot": false, - "name": "penntreebank", + "name": "penntreebank-unlabeled-processed", "files": { "train":{ "uri" : "https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt", From 67940b564ec915b51c1141564aafdb9a21db8dc9 Mon Sep 17 00:00:00 2001 From: Akagiwyf <11912913@mail.sustech.edu.cn> Date: Thu, 21 Apr 2022 22:21:59 +0800 Subject: [PATCH 3/6] change the introduction of PennTreebank --- .../ai/djl/basicdataset/nlp/PennTreebankText.java | 11 ++++++----- .../ai/djl/basicdataset/PennTreebankTextTest.java | 11 +++-------- .../penntreebank-unlabeled-processed/metadata.json | 6 +++--- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java index e62b81bbeb0..909cf8ffbab 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java @@ -11,6 +11,7 @@ * and limitations under the License. */ package ai.djl.basicdataset.nlp; + import ai.djl.Application; import ai.djl.basicdataset.BasicDatasets; import ai.djl.modality.nlp.embedding.EmbeddingException; @@ -30,7 +31,8 @@ /** * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal - * (WSJ) collection of 98,732 stories for syntactic annotation. + * (WSJ) collection of 98,732 stories for syntactic annotation. see here for details */ public class PennTreebankText extends TextDataset { @@ -38,8 +40,7 @@ public class PennTreebankText extends TextDataset { private static final String ARTIFACT_ID = "penntreebank-unlabeled-processed"; /** - * Creates a new instance of {@link PennTreebankText} with the given necessary - * configurations. + * Creates a new instance of {@link PennTreebankText} with the given necessary configurations. * * @param builder a builder with the necessary configurations */ @@ -108,12 +109,12 @@ public void prepare(Progress progress) throws IOException, EmbeddingException { lineArray.add(row); } } - preprocess(lineArray,true); + preprocess(lineArray, true); prepared = true; } /** A builder to construct a {@link PennTreebankText} . */ - public static class Builder extends TextDataset.Builder{ + public static class Builder extends TextDataset.Builder { /** Constructs a new builder. */ public Builder() { diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java index 9c6c81be132..1e920a3a4c1 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java @@ -19,10 +19,7 @@ import ai.djl.training.dataset.Dataset; import ai.djl.training.dataset.Record; import ai.djl.translate.TranslateException; - import java.io.IOException; - - import org.testng.Assert; import org.testng.annotations.Test; @@ -30,7 +27,6 @@ public class PennTreebankTextTest { private static final int EMBEDDING_SIZE = 15; - @Test public void testPennTreebankTextTrainLocal() throws IOException, TranslateException { Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); @@ -57,7 +53,7 @@ public void testPennTreebankTextTrainLocal() throws IOException, TranslateExcept dataset.prepare(); Record record = dataset.get(manager, 0); - Assert.assertEquals(record.getData().get(0).getShape().dimension(),2); + Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); Assert.assertNull(record.getLabels()); } } @@ -88,7 +84,7 @@ public void testPennTreebankTextTestLocal() throws IOException, TranslateExcepti dataset.prepare(); Record record = dataset.get(manager, 0); - Assert.assertEquals(record.getData().get(0).getShape().dimension(),2); + Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); Assert.assertNull(record.getLabels()); } } @@ -119,9 +115,8 @@ public void testPennTreebankTextValidationLocal() throws IOException, TranslateE dataset.prepare(); Record record = dataset.get(manager, 0); - Assert.assertEquals(record.getData().get(0).getShape().dimension(),2); + Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); Assert.assertNull(record.getLabels()); } } - } diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json index 6ed361d659b..2a188b78fff 100644 --- a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json @@ -6,11 +6,11 @@ "artifactId": "penntreebank-unlabeled-processed", "name": "penntreebank-unlabeled-processed", "description": "The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal (WSJ) collection of 98,732 stories for syntactic annotation.", - "website": "https://github.com/wojzaremba/lstm/tree/master/data", + "website": "https://catalog.ldc.upenn.edu/docs/LDC95T7/cl93.html", "licenses": { "license": { - "name": "LDC User Agreement for Non-Members", - "url": "https://catalog.ldc.upenn.edu/license/ldc-non-members-agreement.pdf" + "name": "The Apache License, Version 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0" } }, "artifacts": [ From 00ac1640f14bd78c36ac0823a6faaf80d56edda4 Mon Sep 17 00:00:00 2001 From: Akagiwyf <11912913@mail.sustech.edu.cn> Date: Fri, 22 Apr 2022 10:27:28 +0800 Subject: [PATCH 4/6] Fix the bad format in the introduction in PennTreebankText --- .../main/java/ai/djl/basicdataset/nlp/PennTreebankText.java | 4 ++-- .../ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java | 1 + .../java/ai/djl/basicdataset/AirfoilRandomAccessTest.java | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java index 909cf8ffbab..1ef29f29873 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/PennTreebankText.java @@ -31,8 +31,8 @@ /** * The Penn Treebank (PTB) project selected 2,499 stories from a three year Wall Street Journal - * (WSJ) collection of 98,732 stories for syntactic annotation. see here for details + * (WSJ) collection of 98,732 stories for syntactic annotation (see here for details). */ public class PennTreebankText extends TextDataset { diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java index e59f4bbf177..fd87d45fdfc 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java @@ -86,6 +86,7 @@ public void prepare(Progress progress) throws IOException, EmbeddingException { List targetTextData = new ArrayList<>(); try (BufferedReader reader = Files.newBufferedReader(usagePath)) { String row; + while ((row = reader.readLine()) != null) { String[] text = row.split("\t"); sourceTextData.add(text[0]); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java index dc8028c5c92..c440820faff 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java @@ -89,7 +89,6 @@ public void testAirfoilRemotePreprocessing() throws IOException, TranslateExcept .optLimit(1500) .setSampling(10, true) .build(); - airfoil.prepare(); NDManager manager = NDManager.newBaseManager(); From 1d3cec5c028f097e6ded0e1f2ab944b463901497 Mon Sep 17 00:00:00 2001 From: Zach Kimberg Date: Fri, 29 Apr 2022 11:57:31 -0700 Subject: [PATCH 5/6] Fix license, remove accidental file changes --- .../ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java | 1 - .../java/ai/djl/basicdataset/AirfoilRandomAccessTest.java | 1 + .../penntreebank-unlabeled-processed/metadata.json | 4 ++-- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java index fd87d45fdfc..e59f4bbf177 100644 --- a/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java +++ b/basicdataset/src/main/java/ai/djl/basicdataset/nlp/TatoebaEnglishFrenchDataset.java @@ -86,7 +86,6 @@ public void prepare(Progress progress) throws IOException, EmbeddingException { List targetTextData = new ArrayList<>(); try (BufferedReader reader = Files.newBufferedReader(usagePath)) { String row; - while ((row = reader.readLine()) != null) { String[] text = row.split("\t"); sourceTextData.add(text[0]); diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java index c440820faff..dc8028c5c92 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/AirfoilRandomAccessTest.java @@ -89,6 +89,7 @@ public void testAirfoilRemotePreprocessing() throws IOException, TranslateExcept .optLimit(1500) .setSampling(10, true) .build(); + airfoil.prepare(); NDManager manager = NDManager.newBaseManager(); diff --git a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json index 2a188b78fff..8071911c026 100644 --- a/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json +++ b/basicdataset/src/test/resources/mlrepo/dataset/nlp/ai/djl/basicdataset/penntreebank-unlabeled-processed/metadata.json @@ -9,8 +9,8 @@ "website": "https://catalog.ldc.upenn.edu/docs/LDC95T7/cl93.html", "licenses": { "license": { - "name": "The Apache License, Version 2.0", - "url": "https://www.apache.org/licenses/LICENSE-2.0" + "name": "LDC User Agreement for Non-Members", + "url": "https://catalog.ldc.upenn.edu/license/ldc-non-members-agreement.pdf" } }, "artifacts": [ From 8ece361a39286cfe011cbb22757462b2980e813c Mon Sep 17 00:00:00 2001 From: Akagiwyf <11912913@mail.sustech.edu.cn> Date: Sat, 30 Apr 2022 22:32:21 +0800 Subject: [PATCH 6/6] improve the test method and make it simpler --- .../basicdataset/PennTreebankTextTest.java | 119 +++++------------- 1 file changed, 29 insertions(+), 90 deletions(-) diff --git a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java index 1e920a3a4c1..24d7c42bc6d 100644 --- a/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java +++ b/basicdataset/src/test/java/ai/djl/basicdataset/PennTreebankTextTest.java @@ -15,7 +15,6 @@ import ai.djl.basicdataset.nlp.PennTreebankText; import ai.djl.basicdataset.utils.TextData.Configuration; import ai.djl.ndarray.NDManager; -import ai.djl.repository.Repository; import ai.djl.training.dataset.Dataset; import ai.djl.training.dataset.Record; import ai.djl.translate.TranslateException; @@ -28,95 +27,35 @@ public class PennTreebankTextTest { private static final int EMBEDDING_SIZE = 15; @Test - public void testPennTreebankTextTrainLocal() throws IOException, TranslateException { - Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); - try (NDManager manager = NDManager.newBaseManager()) { - PennTreebankText dataset = - PennTreebankText.builder() - .setSourceConfiguration( - new Configuration() - .setTextEmbedding( - TestUtils.getTextEmbedding( - manager, EMBEDDING_SIZE)) - .setEmbeddingSize(EMBEDDING_SIZE)) - .setTargetConfiguration( - new Configuration() - .setTextEmbedding( - TestUtils.getTextEmbedding( - manager, EMBEDDING_SIZE)) - .setEmbeddingSize(EMBEDDING_SIZE)) - .setSampling(32, true) - .optLimit(100) - .optRepository(repository) - .optUsage(Dataset.Usage.TRAIN) - .build(); - - dataset.prepare(); - Record record = dataset.get(manager, 0); - Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); - Assert.assertNull(record.getLabels()); - } - } - - @Test - public void testPennTreebankTextTestLocal() throws IOException, TranslateException { - Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); - try (NDManager manager = NDManager.newBaseManager()) { - PennTreebankText dataset = - PennTreebankText.builder() - .setSourceConfiguration( - new Configuration() - .setTextEmbedding( - TestUtils.getTextEmbedding( - manager, EMBEDDING_SIZE)) - .setEmbeddingSize(EMBEDDING_SIZE)) - .setTargetConfiguration( - new Configuration() - .setTextEmbedding( - TestUtils.getTextEmbedding( - manager, EMBEDDING_SIZE)) - .setEmbeddingSize(EMBEDDING_SIZE)) - .setSampling(32, true) - .optLimit(100) - .optRepository(repository) - .optUsage(Dataset.Usage.TEST) - .build(); - - dataset.prepare(); - Record record = dataset.get(manager, 0); - Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); - Assert.assertNull(record.getLabels()); - } - } - - @Test - public void testPennTreebankTextValidationLocal() throws IOException, TranslateException { - Repository repository = Repository.newInstance("test", "src/test/resources/mlrepo"); - try (NDManager manager = NDManager.newBaseManager()) { - PennTreebankText dataset = - PennTreebankText.builder() - .setSourceConfiguration( - new Configuration() - .setTextEmbedding( - TestUtils.getTextEmbedding( - manager, EMBEDDING_SIZE)) - .setEmbeddingSize(EMBEDDING_SIZE)) - .setTargetConfiguration( - new Configuration() - .setTextEmbedding( - TestUtils.getTextEmbedding( - manager, EMBEDDING_SIZE)) - .setEmbeddingSize(EMBEDDING_SIZE)) - .setSampling(32, true) - .optLimit(100) - .optRepository(repository) - .optUsage(Dataset.Usage.VALIDATION) - .build(); - - dataset.prepare(); - Record record = dataset.get(manager, 0); - Assert.assertEquals(record.getData().get(0).getShape().dimension(), 2); - Assert.assertNull(record.getLabels()); + public void testPennTreebankText() throws IOException, TranslateException { + for (Dataset.Usage usage : + new Dataset.Usage[] { + Dataset.Usage.TRAIN, Dataset.Usage.VALIDATION, Dataset.Usage.TEST + }) { + try (NDManager manager = NDManager.newBaseManager()) { + PennTreebankText dataset = + PennTreebankText.builder() + .setSourceConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setTargetConfiguration( + new Configuration() + .setTextEmbedding( + TestUtils.getTextEmbedding( + manager, EMBEDDING_SIZE)) + .setEmbeddingSize(EMBEDDING_SIZE)) + .setSampling(32, true) + .optLimit(100) + .optUsage(usage) + .build(); + dataset.prepare(); + Record record = dataset.get(manager, 0); + Assert.assertEquals(record.getData().get(0).getShape().get(1), 15); + Assert.assertNull(record.getLabels()); + } } } }