-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
KexinFeng
committed
Jun 9, 2023
1 parent
9d7737c
commit 4490def
Showing
19 changed files
with
2,142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration | ||
// of the | ||
// autoregressive loop. | ||
// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains | ||
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder). | ||
// The SeqBatcher batch operations will operate on these two dimensions. | ||
public abstract class BatchTensorList { | ||
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastOutputIds; | ||
|
||
// [batch, seq_past] | ||
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDArray pastAttentionMask; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, heads, seq_past, kvfeature] | ||
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. | ||
private NDList pastKeyValues; | ||
|
||
// Sequence dimension order among all dimensions for each element in the batch list. | ||
private long[] seqDimOrder; | ||
|
||
BatchTensorList() {} | ||
|
||
BatchTensorList(NDList list, long[] seqDimOrder) { | ||
this.seqDimOrder = seqDimOrder; | ||
pastOutputIds = list.get(0); | ||
pastAttentionMask = list.get(1); | ||
pastKeyValues = list.subNDList(2); | ||
} | ||
|
||
BatchTensorList( | ||
NDArray pastOutputIds, | ||
NDArray pastAttentionMask, | ||
NDList pastKeyValues, | ||
long[] seqDimOrder) { | ||
this.pastKeyValues = pastKeyValues; | ||
this.pastOutputIds = pastOutputIds; | ||
this.pastAttentionMask = pastAttentionMask; | ||
this.seqDimOrder = seqDimOrder; | ||
} | ||
|
||
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); | ||
|
||
// The pastOutputIds has to be the first in the output list | ||
public abstract NDList getList(); | ||
|
||
public long[] getSeqDimOrder() { | ||
return seqDimOrder; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastOutputIds. | ||
* | ||
* @return the value of pastOutputIds | ||
*/ | ||
public NDArray getPastOutputIds() { | ||
return pastOutputIds; | ||
} | ||
|
||
public void setPastOutputIds(NDArray pastOutputIds) { | ||
this.pastOutputIds = pastOutputIds; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastAttentionMask. | ||
* | ||
* @return the value of pastAttentionMask | ||
*/ | ||
public NDArray getPastAttentionMask() { | ||
return pastAttentionMask; | ||
} | ||
|
||
public void setPastAttentionMask(NDArray pastAttentionMask) { | ||
this.pastAttentionMask = pastAttentionMask; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastKeyValues. | ||
* | ||
* @return the value of pastKeyValues | ||
*/ | ||
public NDList getPastKeyValues() { | ||
return pastKeyValues; | ||
} | ||
|
||
public void setPastKeyValues(NDList pastKeyValues) { | ||
this.pastKeyValues = pastKeyValues; | ||
} | ||
|
||
public void setSeqDimOrder(long[] seqDimOrder) { | ||
this.seqDimOrder = seqDimOrder; | ||
} | ||
} |
76 changes: 76 additions & 0 deletions
76
api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
/** CausalLMOuput is used to contain multiple output of a language model. */ | ||
public class CausalLMOutput { | ||
|
||
// [batch, seq, feature] | ||
// The prob. conditional on a sequence that ends at an element in seq-dim. seq-dim-size = | ||
// |inputIds| | ||
private NDArray logits; | ||
|
||
// [batch, seq, dim] * (layers+1) -> take -1 | ||
// The vec. rep. of a sequence that ends at an element in seq-dim. seq-dim-size = |inputIds| | ||
private NDArray hiddenStates; | ||
|
||
// (k, v) * numLayer, | ||
// kv: [batch, heads, seq_past, feature] | ||
// The cache of past sequence. seq-dim-size == |seq_past| + |inputIds| | ||
private NDList pastKeyValuesList; | ||
|
||
public CausalLMOutput(NDArray logits, NDList pastKeyValues) { | ||
this.logits = logits; | ||
this.pastKeyValuesList = pastKeyValues; | ||
} | ||
|
||
public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) { | ||
this.logits = logits; | ||
this.pastKeyValuesList = pastKeyValueList; | ||
this.hiddenStates = hiddenState; | ||
} | ||
|
||
/** | ||
* Gets the value of the logits. | ||
* | ||
* @return the value of logits | ||
*/ | ||
public NDArray getLogits() { | ||
return logits; | ||
} | ||
|
||
public void setLogits(NDArray logits) { | ||
this.logits = logits; | ||
} | ||
|
||
/** | ||
* Gets the value of the allHiddenStates. | ||
* | ||
* @return the value of allHiddenStates | ||
*/ | ||
public NDArray getHiddenState() { | ||
return hiddenStates; | ||
} | ||
|
||
/** | ||
* Gets the value of the pastKeyValuesList. | ||
* | ||
* @return the value of pastKeyValuesList | ||
*/ | ||
public NDList getPastKeyValuesList() { | ||
return pastKeyValuesList; | ||
} | ||
} |
87 changes: 87 additions & 0 deletions
87
api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveBatchTensorList.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
import ai.djl.ndarray.NDArray; | ||
import ai.djl.ndarray.NDList; | ||
|
||
class ContrastiveBatchTensorList extends BatchTensorList { | ||
// [batch, seq_past, hiddenDim] | ||
// The embed vector of the past seq. seq-dim-size = |past_seq|. Will grow. | ||
private NDArray pastHiddenStates; | ||
|
||
// [batch, vacabSize]. Only the last logits, used to recall candidate token. | ||
private NDArray logits; | ||
|
||
ContrastiveBatchTensorList(NDList list, long[] seqDimOrder) { | ||
super(list.get(0), list.get(1), list.subNDList(4), seqDimOrder); | ||
pastHiddenStates = list.get(2); | ||
logits = list.get(3); | ||
} | ||
|
||
ContrastiveBatchTensorList( | ||
NDArray pastOutputIds, | ||
NDArray pastAttentionMask, | ||
NDArray pastHiddenStates, | ||
NDArray logits, | ||
NDList pastKeyValues, | ||
long[] seqDimOrder) { | ||
super(pastOutputIds, pastAttentionMask, pastKeyValues, seqDimOrder); | ||
this.pastHiddenStates = pastHiddenStates; | ||
this.logits = logits; | ||
} | ||
|
||
public ContrastiveBatchTensorList() {} | ||
|
||
@Override | ||
public ContrastiveBatchTensorList fromList(NDList inputList, long[] seqDimOrder) { | ||
return new ContrastiveBatchTensorList(inputList, seqDimOrder); | ||
} | ||
|
||
@Override | ||
public NDList getList() { | ||
// The pastOutputIds has to be the first in the output list | ||
return new NDList( | ||
getPastOutputIds(), | ||
getPastAttentionMask(), | ||
getPastHiddenStates(), | ||
getLogits()) | ||
.addAll(getPastKeyValues()); | ||
} | ||
|
||
/** | ||
* Gets the value of the pastHiddenStates. | ||
* | ||
* @return the value of pastHiddenStates | ||
*/ | ||
public NDArray getPastHiddenStates() { | ||
return pastHiddenStates; | ||
} | ||
|
||
public void setPastHiddenStates(NDArray pastHiddenStates) { | ||
this.pastHiddenStates = pastHiddenStates; | ||
} | ||
|
||
/** | ||
* Gets the value of the logits. | ||
* | ||
* @return the value of logits | ||
*/ | ||
public NDArray getLogits() { | ||
return logits; | ||
} | ||
|
||
public void setLogits(NDArray logits) { | ||
this.logits = logits; | ||
} | ||
} |
57 changes: 57 additions & 0 deletions
57
api/src/main/java/ai/djl/modality/nlp/generate/GPTConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance | ||
* with the License. A copy of the License is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES | ||
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions | ||
* and limitations under the License. | ||
*/ | ||
package ai.djl.modality.nlp.generate; | ||
|
||
/**
 * GPTConfig is used to store the GPT parameters used to select different versions of GPT.
 *
 * <p>A new instance starts with 12 attention heads, 12 layers and a key/value dimension of 64;
 * each value can be changed through its setter.
 */
public class GPTConfig {
    private int numAttentionHeads;
    private int numLayers;
    private long kvDim;

    /** Constructs a {@code GPTConfig} with 12 attention heads, 12 layers and kvDim 64. */
    public GPTConfig() {
        numAttentionHeads = 12;
        numLayers = 12;
        kvDim = 64;
    }

    /**
     * Gets the value of the numAttentionHeads.
     *
     * @return the value of numAttentionHeads
     */
    public int getNumAttentionHeads() {
        return numAttentionHeads;
    }

    /**
     * Sets the value of the numAttentionHeads.
     *
     * @param numAttentionHeads the new number of attention heads
     */
    public void setNumAttentionHeads(int numAttentionHeads) {
        this.numAttentionHeads = numAttentionHeads;
    }

    /**
     * Gets the value of the numLayers.
     *
     * @return the value of numLayers
     */
    public int getNumLayers() {
        return numLayers;
    }

    /**
     * Sets the value of the numLayers.
     *
     * @param numLayers the new number of layers
     */
    public void setNumLayers(int numLayers) {
        this.numLayers = numLayers;
    }

    /**
     * Gets the value of the kvDim.
     *
     * @return the value of kvDim
     */
    public long getKvDim() {
        return kvDim;
    }

    /**
     * Sets the value of the kvDim.
     *
     * @param kvDim the new key/value dimension
     */
    public void setKvDim(long kvDim) {
        this.kvDim = kvDim;
    }
}
Oops, something went wrong.