
Segmentation fault when enabling Smart schedule #184

Open
Xingzhi107 opened this issue Dec 30, 2023 · 8 comments
@Xingzhi107

Xingzhi107 commented Dec 30, 2023

When I use a custom expert that inherits from the FMoE class and turn on Smart schedule, I get the following error:

[ubuntu:8399 :0:8399] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7f9a6d99aa00)
[ubuntu:8400 :0:8400] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7faacd99aa00)
[ubuntu:8401 :0:8401] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7f1f9b99aa00)
[ubuntu:8398 :0:8398] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7f5defd9ac00)

Debugging with pdb, the error is located at this call:

local_output_buf, gib = fmoe_native.smart_sch_forward(
    local_input_buf,
    local_expert_count, global_expert_count,
    stored_models, fwd_batch_size, ctx.expert_size,
    world_size, _expert_forward, get_param_fn, stash_fn, pop_fn)

It's not clear to me what this constraint on the input and output means:
"The input and output features have to be of the same length for the experts."
My definition goes something like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
from fmoe import FMoE  # FastMoE

class Expert(nn.Module):
    def __init__(
        self,
        d_model, d_hidden,
        rank = 0,
    ):
        super().__init__()

        self.w1 = nn.Linear(
            d_model, d_hidden, bias=False
        )
        self.w2 = nn.Linear(
            d_hidden, d_model, bias=False
        )
        self.w3 = nn.Linear(
            d_model, d_hidden, bias=False
        )

    def forward(self, x, fec=None):

        print(x.shape)
        out = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # print(out.shape)
        return out

class FastMoe(FMoE):
    def __init__(self,
                 num_expert=4,
                 d_model = 1024,
                 d_hidden=4096,
                 activation=torch.nn.SiLU(),
                 world_size =4,
                 top_k = 2,
        ):
        def one_expert(d_model):
            return Expert( d_model, d_hidden)
        expert = one_expert
        super().__init__(num_expert, d_model, world_size,
                         top_k=top_k,expert=expert)
        self.mark_parallel_comm()

DDP is also used:

self.model = self.model.to(rank)
self.model = DDP(self.model)
The terminal command is as follows:
FMOE_FASTER_SCHEDULE_ENABLE=1 FMOE_FASTER_SHADOW_ENABLE=1 FMOE_FASTER_GROUP_SIZE=1 torchrun --standalone --nproc_per_node=4  tools/example.py -m ./ckpts -t ckpts/tokenizer.model

num_expert is 1.
I want to place one expert on each GPU (expert parallelism).
I'd appreciate it if anyone could point me in the right direction.

@laekov
Owner

laekov commented Dec 31, 2023

> It's not clear to me what this constraint on the input and output means:
> "The input and output features have to be of the same length for the experts."

This means the feature dimension of input / output tensors of experts should be equal. I think your code fulfills this requirement.

I am not able to reproduce your issue with the code you provided, using randn or ones as input data. Can you please provide more information about the error? E.g., the shape of x that you print in your expert module. You may also add the -g flag to cxx_flags in setup.py, recompile fastmoe, and run the program with CUDA_LAUNCH_BLOCKING=1 to see which line of CUDA code gives this error.

Also, you can try turning off some features, for example FMOE_FASTER_SHADOW_ENABLE=0 or FMOE_FASTER_GROUP_SIZE=4, and see if any of these changes can bypass the error. If so, we will be able to further inspect specific functions.
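
(For reference, one such debug run, combining the command from the original report with the toggles suggested above, could look like this; the script path and checkpoint arguments are the ones from your example:)

CUDA_LAUNCH_BLOCKING=1 FMOE_FASTER_SCHEDULE_ENABLE=1 FMOE_FASTER_SHADOW_ENABLE=0 FMOE_FASTER_GROUP_SIZE=1 torchrun --standalone --nproc_per_node=4 tools/example.py -m ./ckpts -t ckpts/tokenizer.model

A second run with FMOE_FASTER_GROUP_SIZE=4 instead would help narrow down which feature triggers the fault.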

@Xingzhi107
Author

> > It's not clear to me what this constraint on the input and output means:
> > "The input and output features have to be of the same length for the experts."
>
> This means the feature dimension of input / output tensors of experts should be equal. I think your code fulfills this requirement.
>
> I am not able to reproduce your issue with the code you provided, using randn or ones as input data. Can you please provide more information about the error? E.g., the shape of x that you print in your expert module. You may also add the -g flag to cxx_flags in setup.py, recompile fastmoe, and run the program with CUDA_LAUNCH_BLOCKING=1 to see which line of CUDA code gives this error.
>
> Also, you can try turning off some features, for example FMOE_FASTER_SHADOW_ENABLE=0 or FMOE_FASTER_GROUP_SIZE=4, and see if any of these changes can bypass the error. If so, we will be able to further inspect specific functions.

Thank you for your reply!
I suspected that my Expert class was not written correctly, so I tried fastmoe's LinearExpert instead, but the error is the same.
I use FMoE inside a transformer; my code is as follows:
class TorchTransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = TorchAttention(args)
        self.feed_forward = TorchFFN(
            dim=args.dim,
            hidden_dim=4 * args.dim,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for attention caching.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.

        Returns:
            torch.Tensor: Output tensor after applying attention and feedforward layers.
        """
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_cis, mask
        )
        norm = self.ffn_norm(h)  # ([5, 1, 1024])
        print(norm.shape)
        out = h + self.feed_forward.forward(norm)
        return out

class MoETorchTransformerBlock(TorchTransformerBlock):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__(layer_id, args)

        self.attention = TorchAttention(args)
        assert args.moe["num_experts"] % args.num_gpus == 0, "num_experts must be divisible by num_gpus"
        # print(int(os.environ['WORLD_SIZE']))
        self.feed_forward = FastMoe(
            num_expert=args.moe["num_experts"],
            d_model=args.dim,
            d_hidden=args.hidden_dim,
            activation=torch.nn.SiLU(),
            world_size=int(os.environ['WORLD_SIZE']),
            top_k=args.moe["num_experts_per_tok"],
        )

class FastMoe(FMoE):
    def __init__(self,
                 num_expert=4,
                 d_model=1024,
                 d_hidden=4096,
                 activation=torch.nn.SiLU(),
                 world_size=1,
                 top_k=2,
                 # moe_group = 1,
        ):
        # def one_expert(d_model):
        #     return Expert(d_model, d_hidden)
        # expert = one_expert
        super().__init__(num_expert, d_model, world_size,
                         top_k=top_k, expert=LinearExpert)
        # self.mark_parallel_comm()

    def forward(self, inp: torch.Tensor):
        original_shape = inp.shape
        # print("original_shape:", original_shape)  # [bsz, seq, d]
        inp = inp.reshape(-1, self.d_model)  # [bsz*seq, d]
        output = super().forward(inp)

        return output.reshape(original_shape)

When I don't turn on smart schedule, no error occurs, but when I add FMOE_FASTER_SCHEDULE_ENABLE=1 I get:

attention: torch.Size([5, 1, 1024]) torch.Size([5, 1, 1024])
attention: torch.Size([5, 1, 1024]) torch.Size([5, 1, 1024])
[ubuntu:2697 :0:2697] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7f817590a600)
[ubuntu:2698 :0:2698] Caught signal 11 (Segmentation fault: invalid permissions for mapped object at address 0x7fee15d0a600)
==== backtrace (tid: 2697) ====

The attention's output shape comes from output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1).
torchrun --standalone --nproc_per_node=4 tools/example.py -m ./ckpts -t ckpts/tokenizer.model passes.
FMOE_FASTER_SCHEDULE_ENABLE=1 torchrun --standalone --nproc_per_node=4 tools/example.py -m ./ckpts -t ckpts/tokenizer.model fails.

@Xingzhi107
Author

Xingzhi107 commented Jan 2, 2024

I'm sorry, I have one more question I would like to ask.
I am also puzzled by the global_policy function:

def global_policy(local_expert_count, _gec, num_expert, world_size):
    r"""
    This is the policy for two-layer MLPs, using the formula in the PPoPP paper.
    A few parameters are used in this policy.
    * `d_model`: feature length of the MLP input and output.
    * `alpha`: the ratio of the MLP's hidden size to `d_model`.
    * `bw_net`: bandwidth of the network (GBps)
    * `bw_mm`: computation throughput of performing GeMM (FLOPs)
    """
    bw_net = float_from_env('FMOE_FASTER_GLBPLC_NETBW', 50 * 1e9 / 8)
    bw_mm = float_from_env('FMOE_FASTER_GLBPLC_GPUTP', 11.5e12)
    alpha = float_from_env('FMOE_FASTER_GLBPLC_ALPHA', 2)
    d_model = float_from_env('FMOE_FASTER_GLBPLC_DMODEL', 2048)

    moe_group = get_moe_group()
    local_expert_count = local_expert_count.cuda()
    agecs = [torch.empty_like(local_expert_count) for _ in range(world_size)]
    dist.all_gather(agecs, local_expert_count, group=moe_group)
    all_global_expert_count = torch.stack(agecs)

    # TODO: data type other than float
    data_size = 4

    fwd_expert_counts = all_global_expert_count.sum(1).cpu()
    B_ws, indices = fwd_expert_counts.flatten().sort(0, descending=True)

Should the local_expert_count obtained on each card (world_size) be the same or different? I found that the results I gathered became exactly the same numbers, so after the sum the resulting res was an all-false tensor.
If local_expert_count should be different, maybe I am confusing some concepts, which causes the miscalculation; local_expert_count is calculated here: fmoe_cuda.expert_count(gate, local_expert_count).
In addition, why would the all-false tensor res lead to the Segmentation fault?
Can you give me some advice? Thanks!

@laekov
Owner

laekov commented Jan 3, 2024

> Should the local_expert_count obtained on each card (world_size) be the same or different?

local_expert_count differs on each GPU, because it counts how many samples in the local batch go to each expert.

> Why would the all-false tensor res lead to the Segmentation fault? Can you give me some advice? Thanks!

res in this function indicates which experts are to be shadowed. All false in res means that no expert is being shadowed, which is a common case when the workload is relatively balanced across the experts. I do not think this can lead to a seg fault.
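
(To illustrate the first point, here is a minimal sketch of the idea behind local_expert_count. This is not FastMoE's fmoe_cuda.expert_count kernel, just a plain-PyTorch illustration with hypothetical shapes:)

import torch

# local_expert_count is a per-rank histogram of how many of the local batch's
# top-k routing choices land on each global expert slot.
num_expert, world_size, top_k = 1, 4, 2
tot_expert = num_expert * world_size               # global number of expert slots

# hypothetical gate output for a local batch of 5 tokens: top-k expert indices
gate_idx = torch.randint(0, tot_expert, (5, top_k))

local_expert_count = torch.bincount(gate_idx.flatten(), minlength=tot_expert)
print(local_expert_count)  # differs across ranks when the local batches differ

If every rank feeds the same tokens through the same gate, these histograms will naturally be identical.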

@Xingzhi107
Author

> > Should the local_expert_count obtained on each card (world_size) be the same or different?
>
> local_expert_count differs on each GPU, because it counts how many samples in the local batch go to each expert.
>
> > Why would the all-false tensor res lead to the Segmentation fault? Can you give me some advice? Thanks!
>
> res in this function indicates which experts are to be shadowed. All false in res means that no expert is being shadowed, which is a common case when the workload is relatively balanced across the experts. I do not think this can lead to a seg fault.

Thank you very much for your answer. However, I get an error when I access stored_models_[i]: there is no way to read its value, although I can print its size, and I don't know what went wrong. Its size is always 8. My stored_models is an all-false tensor whose size is num_expert * world_size.

std::vector<torch::Tensor> params;
    auto stored_models_ = stored_models.data_ptr<bool>();
    for (long i = 0; i < num_expert * n_workers; ++i) {
        if (stored_models_[i]) {
            torch::Tensor t = input_buf.new_empty({expert_size});
            if (i / num_expert == rank) {
                get_param_fn(t, i % num_expert);
            }
            params.push_back(t);
        }
    }

In addition, local_expert_count is calculated by a function inside FMoE; is my use of FMoE written incorrectly, causing every local_expert_count to be the same? My num_expert is set to 1, world_size is set from os.environ['WORLD_SIZE'], nnodes is 1, and nproc_per_node=4.

class Expert(nn.Module):
    def __init__(
        self,
        d_model, d_hidden,
        rank = 0,
    ):
        super().__init__()

        self.w1 = nn.Linear(
            d_model, d_hidden, bias=False
        )
        self.w2 = nn.Linear(
            d_hidden, d_model, bias=False
        )
        self.w3 = nn.Linear(
            d_model, d_hidden, bias=False
        )

    def forward(self, x, fec=None):
        # device = x.device
        # x = x.to(self.w1.weight.device)
        out = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # print(out.shape)
        return out

class FastMoe(FMoE):
    def __init__(self,
                 num_expert=4,
                 d_model = 1024,
                 d_hidden=4096,
                 activation=torch.nn.SiLU(),
                 world_size =1,
                 top_k = 2,
                 # moe_group = 1,
        ):
        def one_expert(d_model):
            return Expert( d_model,d_hidden)
        expert = one_expert
        super().__init__(num_expert, d_model, world_size,
                         top_k=top_k,expert=expert,gate=NaiveGate)
        self.mark_parallel_comm("dp")

    def forward(self, inp: torch.Tensor):
        original_shape = inp.shape
        #print("original_shape:",original_shape) #[bsz,seq,d]
        inp = inp.reshape(-1, self.d_model) #[bsz*seq,d]


        # pdb.set_trace()
        output = super().forward(inp)

        return output.reshape(original_shape)

Thank you very much for your guidance again

@laekov
Owner

laekov commented Jan 3, 2024

> However, I get an error when I access stored_models_[i]: there is no way to read its value, although I can print its size.

stored_models_ is the output of the policy_fn, which should be a boolean tensor on CPU. As you report that you cannot access its value, I am wondering if you are setting the default device of PyTorch to GPU, which may be problematic in current FastMoE.
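
(A minimal sketch of this failure mode, assuming the default device was set via torch.set_default_device, which requires PyTorch >= 2.0; the tensor name mirrors the C++ snippet above:)

import torch

# With the default device forced to CUDA, a tensor created without an explicit
# device silently lands on the GPU. The host-side C++ loop above then reads
# stored_models.data_ptr<bool>(), i.e. it dereferences a device pointer from
# CPU code, which plausibly produces the reported "invalid permissions" fault.
torch.set_default_device("cuda")
stored_models = torch.zeros(4, dtype=torch.bool)   # now a CUDA tensor
print(stored_models.device)                        # cuda:0, not cpu

# Allocating the shadow mask with an explicit device avoids the problem:
stored_models_cpu = torch.zeros(4, dtype=torch.bool, device="cpu")
print(stored_models_cpu.device)                    # cpu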

> In addition, local_expert_count is calculated by a function inside FMoE; is my use of FMoE written incorrectly, causing every local_expert_count to be the same? My num_expert is set to 1, world_size is set from os.environ['WORLD_SIZE'], nnodes is 1, and nproc_per_node=4.

local_expert_count being the same is not unusual if your input on each GPU is the same (a quick way to check this is sketched below). You may inspect the output of the gate module and see if they select the same experts.

world_size should be equal to the number of GPUs you use. So, in your case, world_size=4 should be correct.
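
(If it helps, here is a rough sketch of how one could check whether every rank is actually receiving the same input, which would explain identical local_expert_count values; inputs_identical_across_ranks is a hypothetical helper, not part of FastMoE:)

import torch
import torch.distributed as dist

def inputs_identical_across_ranks(inp: torch.Tensor) -> bool:
    # Gather the local MoE input from every rank and compare. If all ranks see
    # the same tensor, identical local_expert_count values are expected rather
    # than a routing bug.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(inp) for _ in range(world_size)]
    dist.all_gather(gathered, inp.contiguous())
    return all(torch.equal(gathered[0], t) for t in gathered[1:])

Calling this on inp inside FastMoe.forward, right before super().forward(inp), would show whether the four processes are feeding the same batch (for example, if the data loader is not sharded by rank).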

@Xingzhi107
Author

> stored_models_ is the output of the policy_fn, which should be a boolean tensor on CPU. As you report that you cannot access its value, I am wondering if you are setting the default device of PyTorch to GPU, which may be problematic in current FastMoE.

Yes, it was indeed for this reason. Thank you very much for your help!

@laekov
Owner

laekov commented Jan 3, 2024

Well, thank you very much for reporting and debugging this issue. I think we should explicitly specify the device of tensors when we allocate them in our library. We will update the codebase before closing this issue.

@laekov laekov self-assigned this Jan 3, 2024
@laekov laekov added the bug Something isn't working label Jan 3, 2024