Reverse Cumulative Sum #33520

Open
lematt1991 opened this issue Feb 19, 2020 · 6 comments
Labels: `function request` (a request for a new function or the addition of new arguments/modes to an existing function), `triaged` (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@lematt1991

🚀 Feature

Add a `reverse` option to `torch.cumsum`, as in TensorFlow (`tf.math.cumsum(..., reverse=True)`).

Motivation

This would compute a right-to-left cumulative sum more efficiently. Currently, as far as I know, the only way to do it is to flip, take the cumsum, and flip back:

```python
import torch

x = torch.arange(9).view(3, 3)
r2lcumsum = torch.flip(torch.cumsum(torch.flip(x, [1]), 1), [1])
```

The result should be:

```
tensor([[ 3,  3,  2],
        [12,  9,  5],
        [21, 15,  8]])
```

Pitch

Add a `reverse` argument to the native `cumsum` function.
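For illustration only, the proposed option can be emulated at the Python level. `cumsum_with_reverse` below is a hypothetical wrapper, not a PyTorch API; it simply reuses the flip-based workaround from the motivation, so it shows the desired semantics rather than the hoped-for speedup.

```python
import torch

def cumsum_with_reverse(x: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor:
    """Hypothetical wrapper emulating a `reverse` flag for torch.cumsum."""
    if not reverse:
        return torch.cumsum(x, dim=dim)
    # Right-to-left cumulative sum: flip, scan left to right, flip back.
    return torch.cumsum(x.flip(dims=(dim,)), dim=dim).flip(dims=(dim,))

x = torch.arange(9).view(3, 3)
print(cumsum_with_reverse(x, dim=1, reverse=True).tolist())
# [[3, 3, 2], [12, 9, 5], [21, 15, 8]]
```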

Alternatives

Additional context

@albanD added the `feature`, `module: operators`, and `triaged` labels on Feb 21, 2020
@vadimkantorov (Contributor)

Also stumbled on this. @lematt1991, thanks for the workaround!

@xuuuluuu

Please also add a `reverse` option to the `cumprod` function.

@mruberry added the `function request` label and removed the `feature` and `module: operators (deprecated)` labels on Oct 10, 2020
@yufeng66

Yes, please add this. Using the workaround doubles the run time.
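To check how much the two flips cost on your own hardware, one rough approach is to time the workaround against a plain forward `cumsum`, which we assume is roughly what a native reverse option would cost. The tensor size and repeat count below are arbitrary, and results will vary by device, dtype, and PyTorch version.

```python
import timeit
import torch

x = torch.randn(512, 2048)  # arbitrary size chosen for illustration

def forward_cumsum(t: torch.Tensor) -> torch.Tensor:
    # Plain left-to-right scan, used as a rough stand-in for a native reverse option.
    return torch.cumsum(t, dim=1)

def flip_workaround(t: torch.Tensor) -> torch.Tensor:
    # Right-to-left scan via flip -> cumsum -> flip (the workaround from the issue).
    return torch.flip(torch.cumsum(torch.flip(t, [1]), 1), [1])

timings = {}
for fn in (forward_cumsum, flip_workaround):
    # On a GPU you would also need torch.cuda.synchronize() inside the timed call.
    timings[fn.__name__] = timeit.timeit(lambda: fn(x), number=20) / 20
    print(f"{fn.__name__}: {timings[fn.__name__] * 1e3:.3f} ms per call")
```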

@yash-s20

yash-s20 commented Apr 3, 2021

Here's a faster workaround:

```python
import torch

x = torch.arange(9).view(3, 3)
r2lcumsum = x +  torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1)
```

This avoids `flip` entirely, which seems to be the bottleneck.
Hope this helps someone!

I didn't check, but the same trick should also work for `cumprod`: replace `sum` with `prod`, addition with multiplication, and subtraction with division (unless, of course, the tensor contains zeros).
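A minimal sketch of the division trick for a reverse `cumprod`, next to a flip-based version; both helper names are hypothetical. The reverse product at position i is the total product divided by the product of everything to its left, i.e. `x * total / cumprod`. As the comment warns, division fails once a zero appears: the total and the running product both become 0, so entries from the first zero onward come out as `nan` (0/0).

```python
import torch

def reverse_cumprod_div(x: torch.Tensor, dim: int) -> torch.Tensor:
    # total / (product of elements to the left) == x * total / cumprod.
    # Produces nan (0 / 0) wherever the running product has hit zero.
    cp = torch.cumprod(x, dim=dim)
    total = torch.prod(x, dim=dim, keepdim=True)
    return x * total / cp

def reverse_cumprod_flip(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Flip-based version: handles zeros, but pays for two flips.
    return torch.cumprod(x.flip(dims=(dim,)), dim=dim).flip(dims=(dim,))

x = torch.tensor([2.0, 3.0, 4.0])
print(reverse_cumprod_div(x, 0).tolist())   # [24.0, 12.0, 4.0]
print(reverse_cumprod_flip(x, 0).tolist())  # [24.0, 12.0, 4.0]

z = torch.tensor([2.0, 0.0, 4.0])
print(reverse_cumprod_div(z, 0).tolist())   # [0.0, nan, nan]
print(reverse_cumprod_flip(z, 0).tolist())  # [0.0, 0.0, 4.0]
```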

@Yonv1943

Yonv1943 commented Nov 4, 2022

Based on the answer from @yash-s20:

We don't have to compute `torch.sum(x)` separately, because `torch.cumsum(x, dim=0)[-1] == torch.sum(x)`.

So we have:

```python
import torch

x = torch.arange(1, 4)
print('[1, 2, 3]', x)

cum_sum = torch.cumsum(x, dim=0)             # faster way
re_cum_sum = x - cum_sum + cum_sum[-1:None]  # faster way
print('[1, 3, 6]', cum_sum)
print('[6, 5, 3]', re_cum_sum)

re_cum_sum = torch.cumsum(x.flip(dims=(0,)), dim=0).flip(dims=(0,))  # slower way
print('[6, 5, 3]', re_cum_sum)
```
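As written, `cum_sum[-1:None]` slices dimension 0, so it only works for 1-D input; on the 3×3 example scanned along `dim=1`, it would pick up the last *row* instead of each row's total. A dim-agnostic sketch of the same idea (the helper name is hypothetical) can use `Tensor.narrow` to take the last slice along the scanned dimension, which keeps size 1 there and therefore broadcasts:

```python
import torch

def reverse_cumsum_sub(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Reverse cumsum via x - cumsum + total, reusing the last cumsum slice as the total."""
    cs = torch.cumsum(x, dim=dim)
    # Last slice along `dim` (size 1 in that dimension, so it broadcasts against cs).
    total = cs.narrow(dim, cs.size(dim) - 1, 1)
    return x - cs + total

x = torch.arange(9).view(3, 3)
print(reverse_cumsum_sub(x, dim=1).tolist())  # [[3, 3, 2], [12, 9, 5], [21, 15, 8]]
```

The same subtraction caveat discussed in the thread applies to floating-point inputs.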

@rtqichen (Contributor)

rtqichen commented May 20, 2024

A small note: the "faster" way of subtracting the cumsum from the full sum can introduce additional round-off error, because it adds and then subtracts many unwanted terms. This matters if your application requires high numerical precision. A native implementation of reverse cumsum would still be ideal, as it would be both faster and more numerically stable.
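A small, contrived float32 example (the input values are an assumption chosen to make the effect visible) shows how the subtraction-based formula can lose small tail sums entirely while the flip-based version keeps them. The flip-based result still rounds its first entry, since 1e8 + 2 is not representable in float32, but its last two entries are exact.

```python
import torch

# Values of very different magnitude in float32 (adjacent floats near 1e8 are 8 apart).
x = torch.tensor([1e8, 1.0, 1.0], dtype=torch.float32)

# Exact reverse cumsum would be [1e8 + 2, 2, 1].
flip_based = torch.cumsum(x.flip(0), dim=0).flip(0)

cs = torch.cumsum(x, dim=0)
subtract_based = x - cs + cs[-1:]  # the "faster" formula from the comments above

print(flip_based.tolist())      # [100000000.0, 2.0, 1.0]
print(subtract_based.tolist())  # [100000000.0, 0.0, 0.0]
```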

michalkuligowski added a commit to HabanaAI/vllm-fork that referenced this issue Sep 17, 2024
## One line description

Use topk instead of sort for topp/topk calculation under certain
conditions (scalar value of p and k).

## Details

Instead of using `k` for topk, we use `_padded_k`, which is strictly
larger than k and monotonically non-decreasing.

We need `_padded_k > k` for cases where the smallest of the top-k values
is tied with values beyond position k. For example, for
[9,8,8,8,7,7,7] with k=3, the tied top-k set is [9,8,8,8]: 4 values
instead of 3.
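To make the tie problem concrete, a plain `torch.topk` with k=3 on that example returns only three values, even though four entries are at least as large as the 3rd-largest value. A minimal check:

```python
import torch

logits = torch.tensor([9.0, 8.0, 8.0, 8.0, 7.0, 7.0, 7.0])
k = 3

values, _ = torch.topk(logits, k)           # sorted, largest first
kth_value = values[-1]                      # smallest value inside the top-k (8.0)
num_tied_or_larger = int((logits >= kth_value).sum())

print(values.tolist())      # [9.0, 8.0, 8.0]
print(num_tied_or_larger)   # 4: one tied 8.0 is left out by topk(k=3)
```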

To prevent excessive recompilations, whenever `_padded_k` has to be
expanded we increase it by a fixed constant `_increment` (usually >1).
This bucketed approach limits the number of distinct shapes.


### Basic outline

1. Perform topk with `_padded_k`.
2. Find the "kth" value in each row (the smallest number that will be in
the top-k) and count its duplicates; this count is
`num_duplicates_of_smallest_of_topk`.
3. Find the maximum number of duplicates across rows; this is
`max_num_duplicates_of_smallest_of_topk`.
4. Check whether `_padded_k` is big enough to contain
`max_num_duplicates_of_smallest_of_topk`. If not, expand `_padded_k`
and redo the topk with the expanded `_padded_k`.
5. Mask out the extra values in `_padded_k`.
6. Move on to top-p.
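A minimal sketch of the outline above for the scalar-k case, assuming a 2-D `logits` tensor and `k <= padded_k`. The helper name and its arguments are hypothetical, not the PR's actual code: instead of reproducing the duplicate-count variables, it counts how many of the padded values tie with or exceed the k-th value, and it omits the top-p step.

```python
import torch

def topk_keep_ties(logits: torch.Tensor, k: int, padded_k: int, increment: int):
    """Hypothetical sketch: per-row top-k that also keeps values tied with the k-th largest.

    `padded_k` / `increment` play the role of `_padded_k` / `_increment`.
    Returns (values, indices, padded_k); values outside the tied top-k are -inf.
    """
    vocab_size = logits.shape[-1]
    while True:
        padded_k = min(padded_k, vocab_size)
        # topk with the padded size; values come back sorted in descending order.
        values, indices = torch.topk(logits, padded_k, dim=-1)
        # The k-th value of each row is the smallest value inside the top-k.
        kth_value = values[..., k - 1 : k]
        # Per row, how many of the padded values tie with or exceed it.
        num_kept = (values >= kth_value).sum(dim=-1)
        # If some row is entirely >= kth_value, ties may continue past padded_k:
        # grow by a fixed increment (bucketed shapes) and redo the topk.
        if int(num_kept.max()) < padded_k or padded_k == vocab_size:
            break
        padded_k += increment
    # Mask out the extra values beyond the tied top-k set.
    values = values.masked_fill(values < kth_value, float("-inf"))
    return values, indices, padded_k

logits = torch.tensor([[9.0, 8, 8, 8, 7, 7, 7],
                       [5.0, 4, 3, 2, 1, 0, -1]])
values, indices, padded_k = topk_keep_ties(logits, k=3, padded_k=4, increment=2)
print(torch.isfinite(values).sum(dim=-1).tolist())  # [4, 3]
print(padded_k)                                     # 6
```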


## Perf benefit

### Using benchmark_throughput.py

To check the benefit of this PR, make the following change in
`benchmark_throughput.py`:
```
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index ff33e3dc..3383dea8 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -116,8 +116,9 @@ def run_vllm(
         sampling_params.append(
             SamplingParams(
                 n=n,
-                temperature=0.0 if use_beam_search else 1.0,
-                top_p=1.0,
+                temperature=1.0,  #0.0 if use_beam_search else 1.0,
+                top_p=0.95,
+                top_k=20,
                 use_beam_search=use_beam_search,
                 ignore_eos=True,
                 max_tokens=output_len,

 ```


`VLLM_SKIP_WARMUP=true VLLM_GRAPH_RESERVED_MEM=0.2 VLLM_GRAPH_PROMPT_RATIO=0.8 VLLM_DECODE_BS_BUCKET_MIN=1 VLLM_DECODE_BLOCK_BUCKET_STEP=64 VLLM_DECODE_BLOCK_BUCKET_MIN=64 python benchmark_throughput.py --model /root/sasarkar/llama3-8b/ --device hpu --seed 2024 --backend vllm --num-prompts 100 --dtype bfloat16 --input-len=256 --output-len=512`

In the numbers below, throughput increases by **49%** with warmup and by **30%** without warmup.


#### with opt + warmup

Processed prompts: 100%|█████████████████████████████████████████████████████████████████████| 100/100 [00:22<00:00,  4.37it/s, est. speed input: 1119.66 toks/s, output: 2239.33 toks/s]
Throughput: 4.37 requests/s, 3354.58 tokens/s


#### with opt + skip warmup

Processed prompts: 100%|██████████████████████████████████████████████████████████████████████| 100/100 [00:46<00:00,  2.17it/s, est. speed input: 556.32 toks/s, output: 1112.63 toks/s]
Throughput: 2.17 requests/s, 1667.89 tokens/s


#### without opt + warmup

Processed prompts: 100%|██████████████████████████████████████████████████████████████████████| 100/100 [00:34<00:00,  2.93it/s, est. speed input: 749.24 toks/s, output: 1498.48 toks/s]
Throughput: 2.92 requests/s, 2245.74 tokens/s


#### without opt + skip warmup

Processed prompts: 100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:59<00:00,  1.67it/s, est. speed input: 428.49 toks/s, output: 856.99 toks/s]
Throughput: 1.67 requests/s, 1284.85 tokens/s
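The percentages quoted above follow from the `tokens/s` figures; a quick arithmetic check:

```python
def gain_percent(baseline: float, optimized: float) -> float:
    """Relative throughput increase of `optimized` over `baseline`, in percent."""
    return (optimized / baseline - 1.0) * 100.0

print(round(gain_percent(2245.74, 3354.58)))  # 49 (with warmup)
print(round(gain_percent(1284.85, 1667.89)))  # 30 (skip warmup)
```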

### Using server client
(Data collected by Peter)
- Baseline: [baseline](https://github.com/HabanaAI/vllm-fork/commits/a7763a7a76b4531ed7907549724df2949d9225bf/)
- All numbers on 1.17-495
- Third column: [branch](https://github.com/HabanaAI/vllm-fork/commits/ae_benchmark_9_10_24/)

| model | TP | baseline HPU throughput | baseline HPU + this PR throughput | baseline HPU + this PR + other opt throughput |
| -------- | ------- | ------- | ------- | ------- |
| llama3 8b | 1 | 950 | 1296 | 1306 |
| llama3 8b | 4 | 1347 | 1969 | 2077 |
| llama3 70b | 4 | 368 | 394 | 394 |
| qwen 72b | 4 | 731 | 726 | 815 |


### Without delayed sampling 
On habana_main f858d43
```
VLLM_GRAPH_RESERVED_MEM=0.2 VLLM_GRAPH_PROMPT_RATIO=0.8 \
VLLM_DECODE_BS_BUCKET_MIN=1 VLLM_DECODE_BLOCK_BUCKET_STEP=64 \
VLLM_DECODE_BLOCK_BUCKET_MIN=64 python benchmark_throughput.py --model \
/root/sasarkar/llama3-8b/ --device hpu --seed 2024 --backend vllm \
--num-prompts 100 --dtype bfloat16 --input-len=256 --output-len=512
```

Without change:
Throughput: 3.32 requests/s, 2550.85 tokens/s

With change:
Throughput: 5.17 requests/s, 3967.58 tokens/s




## Extra Notes
1. Works only for the "scalar" case, though it might be possible to
extend the basic idea (topk instead of sort) to the vector case as well.
(Outline: find the max k in the top-k vector, then perform topk using
that, etc. This possibly needs some bucketing to prevent dynamic shapes.)
2. We need an additional check in `_init_sampling_tensors` to determine
whether it's the scalar case. This has a minor perf hit; ideally, the
caller would tell us up front that it's a scalar...
3. Some tradeoffs can be made: we could use a sufficiently large
padded_k (still smaller than the vocab size) from the beginning and
hope that every case lands within that bucket. Cases that won't land
are expected to be very, very rare. For example, if padded_k = max(2 * k,
100) is used and, say, k = 50, then the smallest of the top-k values
would need to repeat 50 times with the same probability, which is
exceedingly unlikely. If we accept this small risk, we can get by with
only one topk op, which might be faster.
4. There is a `fliplr` in the code, which could be removed if we could
compute a reverse cumsum directly. However, the formula for reverse
cumsum given [here](pytorch/pytorch#33520),
`x + torch.sum(x, dim=1, keepdims=True) - torch.cumsum(x, dim=1)`, is
numerically unstable because of the addition/subtraction. It works well
enough on ints and large numbers, but not on small probability values.
5. The value of `k` affects the gains we might get from this. For
example, in the experiment shown above (skip warmup), with k=20
throughput increases from 1284.85 to 1667.89 (30% gain). With k = 2000
instead of 20, throughput increases from 1127.34 to 1289.26 (14% gain).
So the gain may shrink as k grows, since topk probably converges
asymptotically to sort's performance for large k. In practice, however,
k is usually small.
6. For larger models, the gains may be smaller, probably because they
are more device-bound.
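7. Cumsum may be slow. Some alternatives to try ([initial
try](b392ff8)):
```
import torch
y = torch.tensor([[1,2,3], [4,5,6]])
# Lower-triangular mask: row i selects elements 0..i, so summing the masked
# product over the last dim equals torch.cumsum(y, dim=1) -> [[1, 3, 6], [4, 9, 15]]
mask1 = torch.tensor([[[1,0,0],[1,1,0],[1,1,1]], [[1,0,0],[1,1,0],[1,1,1]]])
torch.sum(y.unsqueeze(1)*mask1,2)
```
or
```
import torch.nn.functional as F
# Left-pad with n-1 zeros and slide a window of n ones: output j = sum of inputs 0..j
F.conv1d(torch.tensor([[[0,0,0,0,1,2,3,4,5]], [[0,0,0,0,6,7,8,9,10.0]]]), torch.ones([1,1,5], dtype=torch.float32))
```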
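Following the masking idea in note 7, the same trick can produce a reverse cumsum directly, which would also sidestep the `fliplr` mentioned in note 4 without the cancellation of the subtraction formula: multiplying by a lower-triangular matrix of ones sums each element with everything to its right. This is a sketch, not the PR's code, and the helper name is hypothetical; it costs O(n²) work and memory for a scan of length n, so it likely only pays off for short scan lengths.

```python
import torch

def reverse_cumsum_matmul(x: torch.Tensor) -> torch.Tensor:
    """Reverse cumsum along the last dim via a lower-triangular matrix of ones."""
    n = x.shape[-1]
    # tril[m, j] = 1 when m >= j, so (x @ tril)[..., j] = sum of x[..., m] for m >= j.
    tril = torch.tril(torch.ones(n, n, dtype=x.dtype, device=x.device))
    return x @ tril

x = torch.arange(9, dtype=torch.float32).view(3, 3)
print(reverse_cumsum_matmul(x).tolist())
# [[3.0, 3.0, 2.0], [12.0, 9.0, 5.0], [21.0, 15.0, 8.0]]
```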