Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add triton_group_norm #4

Open
wants to merge 19 commits into
base: jieru_triton
Choose a base branch
from

Conversation

nishirong
Copy link

PR types

PR changes

Describe

实现 triton group norm


tune_and_invoke_part_with_two_kernels = tune_and_invoke_part.replace("${op_name}", "${first_kernel_name}").replace("run_triton_kernel", "run_triton_first_kernel") + tune_and_invoke_part.replace("${op_name}", "${second_kernel_name}").replace("run_triton_kernel", "run_triton_second_kernel")

# tune_and_invoke_part_with_two_kernels = """
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果下面的东西的确可以由上面的语句产生,就把注释的语句都删了吧

offset_block = tl.arange(0, BLOCK_SIZE_M)
data_start = batch_id * batch_stride + group_id * group_stride
sample_ptrs = sample_ptr + data_start + offset_channel[:, None] * channel_stride + offset_block[None, :] * hw_stride + block_id * BLOCK_SIZE_M
# 计算均值
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成英文

_sum = tl.sum(sample)

_sum_squares = tl.sum(sample * sample)
output_start = batch_id * group_num + group_id + tl.arange(0,1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确定需要加上tl.arange(0,1)吗?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants