
2D matmul with specific shapes gives bad PCC on a TG #10936

Closed
johanna-rock-tt opened this issue Jul 31, 2024 · 39 comments
Labels: bug (Something isn't working), llama3, LLM_bug, llm_tg (LLMs on Metal), op_cat: mm, P1, prefill (LLM models have prefill mode and its optimization is usually separated from decode mode)

Comments

@johanna-rock-tt
Contributor

johanna-rock-tt commented Jul 31, 2024

Describe the bug
A 2D fractured matmul with specific shapes gives bad PCC when run on a Galaxy (TG).
The shape is the Llama3-405B FF1 shape for prefill with sequence length = 512; both activations and weights are in DRAM interleaved format. The matmul works with lower sequence lengths (= M), e.g. 128 or 256.

M = 512, K = 16 * 1024, N = 52 * 1024

The PCC is (so far) always around zero but still non-deterministic, e.g. in three runs:

PCC: 0.0002399902230190794
PCC: -0.00038166219156071467
PCC: -0.00017789431408447355
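
For reference, a minimal sketch of how a PCC check like the one above is typically computed (the helper name is illustrative, not the actual test utility):

import torch

def compute_pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation of the flattened tensors; the off-diagonal entry
    # of the 2x2 correlation matrix is the PCC.
    stacked = torch.stack([expected.flatten().float(), actual.flatten().float()])
    return torch.corrcoef(stacked)[0, 1].item()

# A PCC near zero, as in the runs above, means the device output is
# essentially uncorrelated with the torch reference.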

To Reproduce
Steps to reproduce the behavior:

  1. Reproduce on a TG
  2. Check out the branch jrock/llama3-405b
  3. Run the unit test: pytest tests/ttnn/multichip_unit_tests/test_multidevice_TG.py::test_galaxy_matmul_2d_fracture[silicon_arch_name=wormhole_b0-silicon_arch_wormhole_b0=True-Llama3-405B_prefill_seq512_FF1-4x8_grid]

Expected behavior
The test passes with the 0.99 PCC target.

@johanna-rock-tt added the bug, op_cat: mm, LLM_bug, prefill, and llama3 labels on Jul 31, 2024
@johanna-rock-tt
Contributor Author

I think this might be a di/dt issue; @pavlepopovic mentioned that they are seeing similar behaviour for di/dt on galaxies.

@davorchap
Collaborator

@johanna-rock-tt is the grid size 8x4 (32 cores)?

looping in @TT-BrianLiu and @tt-aho

Is this related to the ND PCC here: #10673?

@davorchap
Collaborator

Also, is this a Galaxy-specific problem? Does the same shape/config pass on N300?

@cglagovichTT self-assigned this on Aug 9, 2024
@cglagovichTT
Contributor

I ran Johanna's test on 4x8 and 8x8 grids on TG with default program config, reproed ND PCC.

4x8 matmul_2d with subblock 1x1 passes 50 iterations of testing.

This test uses matmul_2d and #10673 uses dram sharded matmul. They may be related if galaxy is fragile with matmul di/dt.

Note that when I first started testing, the 4x8 matmul_2d with default program config failed very consistently, >50% of the time. After the machine warmed up, I found the 4x8 failures to be much less frequent, about 5% of the time. Is there any systems explanation for this?

@davorchap
Collaborator

> I ran Johanna's test on 4x8 and 8x8 grids on TG with default program config, reproed ND PCC.
>
> 4x8 matmul_2d with subblock 1x1 passes 50 iterations of testing.
>
> This test uses matmul_2d and #10673 uses dram sharded matmul. They may be related if galaxy is fragile with matmul di/dt.
>
> Note that when I first started testing, the 4x8 matmul_2d with default program config failed very consistently, >50% of the time. After the machine warmed up, I found the 4x8 failures to be much less frequent, about 5% of the time. Is there any systems explanation for this?

Is the matmul running on 32 devices when it fails?

It would be great to run the same test on T3K, to see if it's a Galaxy-specific issue.

@cglagovichTT
Contributor

Yes, the matmul is running on 32 devices when it fails. I'll check on T3K.

@davorchap
Collaborator

> Yes, the matmul is running on 32 devices when it fails. I'll check on T3K.

It would also be great to see what happens when it runs on 1, 2, 4, and 8 devices on TG.

@cglagovichTT
Contributor

I ran the failing 4x8 core grid with the default program config on T3K (8 chips) and got deterministic, good PCC.

@pavlepopovic
Contributor

pavlepopovic commented Aug 12, 2024

I've managed to run this MM with exactly the same config on T3K as on Galaxy this morning; non-determinism occurs every iteration:
[screenshot: non-deterministic PCC output across iterations]
Now, there's something odd that I've spotted while running this on Galaxy with tracy:

  • This runs a 13-core MM (not 32); it appears that ttnn wants to do matmul1d, where the width gets laid out across cores (208 tiles in the W dimension split over 13 cores)
  • ttnn chooses these kernels to run the matmul with:
    • compute: ['ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm.cpp']
    • datamovement: ['ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout.cpp', 'ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/writer_bmm_tile_layout.cpp']

I don't know about the DM kernels, but the compute kernel seems fishy: it's using the bmm_large_block_zm.cpp kernel, and I know that we usually use bmm_large_block_zm_fused_bias_activation.cpp.
The one used seems like a simpler version of bmm_large_block_zm_fused_bias_activation.cpp.

Attaching the entire tracy run for galaxy:
llama_nd_mm.csv

@uaydonat
Contributor

Specify a 1D/2D matmul config to fix.

@uaydonat
Contributor

@bbradelTT to identify why the bad matmul is used.

@bbradelTT
Contributor

3. test_galaxy_matmul_2d_fracture

@pavlepopovic branch jrock/llama3-405b doesn't exist. Is the code in main now? If not, what steps did you use to reproduce the issue?

@TT-BrianLiu
Contributor

A couple of things:

  • We shouldn't draw any conclusions from tests that have 0.000 PCC. The output is probably complete garbage. "ND" could come from reading/writing to undefined space, which might change per run.
  • Try to use a specific matmul variant instead of the default matmul in models. You can bypass the automatic paths that ttnn.matmul takes. These paths become more and more involved as we try to support more behaviour.
  • Try to run your matmul on a single chip (with properly scaled-down specs) and see if that works. If the PCC is good, then go to multi-chip.

@bbradelTT
Contributor

> @pavlepopovic branch jrock/llama3-405b doesn't exist. Is the code in main now? If not, what steps did you use to reproduce the issue?

Talked to @pavlepopovic
The code is already in main in tests/ttnn/multichip_unit_tests/test_multidevice_TG.py
It's just commented out and the file needs to be updated. See

# pytest.param(
#     512, 16 * 1024, 52 * 1024, ttnn.bfloat4_b, id="Llama3-405B_prefill_seq512_FF1"
# ),  # PCC check failed, PCC: -0.00014127559109112134, see issue 10936

Also, we've started seeing similar behaviour in sweeps recently. See #9059 (comment)

@bbradelTT
Contributor

My current theory is that a fallthrough path, added as a last resort for automatically choosing parameters, distributes tensors while ignoring their shard shapes (e.g. if the shard shape is 768 and M is 768, the code will try to use 768/32 = 24 cores instead of 1), and the underlying kernels are not expecting such settings. I'm continuing to investigate.

@pavlepopovic
Regarding your comment: #10936 (comment) : "I've managed to run this MM with exactly the same config on t3k as on galaxy this morning, non-determinism occurs every iteration:"

Could you please

  • send me your test and how to run it?
  • if your matmul call has core_grid, remove it; if the matmul call does not have core_grid, try to pass in something that takes the tensors' shard shapes into account?

When I tried to specify core_grid in tests/ttnn/multichip_unit_tests/test_multidevice_TG.py::test_galaxy_matmul_2d_fracture on Galaxy, I got an exception: Statically allocated circular buffers on core range [(x=0,y=0) - (x=7,y=0)] grow to 4191072 B which is beyond max L1 size of 1499136 B
I'm not sure if you'd get the same thing.

@bbradelTT
Contributor

@johanna-rock-tt
You can use something like the following for that specific setup:

            program_config=ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
                compute_with_storage_grid_size=(8, 8),
                in0_block_w=1,
                out_subblock_h=1,
                out_subblock_w=1,
                per_core_M=2,
                per_core_N=26,
                transpose_mcast=False,
                fused_activation=None,
            ),

E.g.

 # Llama FF1, FF2, FF3 in MLP with dram interleaved weights
@@ -146,6 +146,17 @@ def test_galaxy_matmul_2d_fracture(M, K, N, weights_dtype, mesh_shape, device_me
         act,
         weights,
         dtype=ttnn.bfloat16,
+        #core_grid=core_grid,
+        program_config=ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
+                compute_with_storage_grid_size=(8, 8),
+                in0_block_w=1,
+                out_subblock_h=1,
+                out_subblock_w=1,
+                per_core_M=2,
+                per_core_N=26,
+                transpose_mcast=False,
+                fused_activation=None,
+        ),
         compute_kernel_config=compute_kernel_lofi,
         memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if M == 32 else ttnn.DRAM_MEMORY_CONFIG,
     )

Leads to

...
2024-08-18 17:01:12.511 | DEBUG    | conftest:device_mesh:193 - multidevice with 32 devices is created
2024-08-18 17:02:29.534 | INFO     | tests.ttnn.multichip_unit_tests.test_multidevice_TG:test_galaxy_matmul_2d_fracture:168 - PCC value: Max ATOL Delta: 80.65719604492188, Max RTOL Delta: inf, PCC: 0.9931608583682504
PASSED2024-08-18 17:02:31.086 | DEBUG    | ttnn:manage_config:93 - Restored ttnn.CONFIG.report_name to None
...
PASSED tests/ttnn/multichip_unit_tests/test_multidevice_TG.py::test_galaxy_matmul_2d_fracture[silicon_arch_name=wormhole_b0-silicon_arch_wormhole_b0=True-Llama3_405B_prefill_seq512_FF1-8x4_grid]

I still need to investigate why the default is resulting in incorrect behaviour.
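
For context, a minimal arithmetic sketch of where these per_core values come from, assuming per-device shapes of 512 x 4096 (activations) and 4096 x 6656 (weights) as in the debug output below, with 32x32 tiles and an 8x8 grid:

TILE = 32
M, N = 512, 6656                        # per-device output rows / columns (assumed)
grid_rows, grid_cols = 8, 8             # compute_with_storage_grid_size=(8, 8)

per_core_M = (M // TILE) // grid_rows   # 16 M-tiles / 8 grid rows = 2
per_core_N = (N // TILE) // grid_cols   # 208 N-tiles / 8 grid cols = 26
print(per_core_M, per_core_N)           # -> 2 26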

@bbradelTT
Contributor

I also tried the following, which led to nothing happening (probably a hang):

            ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
                compute_with_storage_grid_size=(2, 8),
                in0_block_w=1,  # K // 32
                out_subblock_h=4,  # 8 // (N // 32)
                out_subblock_w=2,  # N // 32
                per_core_M=16,  # M // 32
                per_core_N=13,  # N // num_cores(16) // 32
                fuse_batch=False,
                fused_activation=None,
                mcast_in0=True,
            ),

In terms of what the code chooses by default, I got the following debug output:

                     Op | DEBUG    | Auto generated program config: MatmulMultiCoreNonOptimizedReuseProgramConfig()
...
                     Op | DEBUG    |    0: Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt)),shape=ttnn.Shape([1, 1, 512, 4096]),dtype=DataType::BFLOAT16,layout=Layout::TILE)
                     Op | DEBUG    |    1: Tensor(storage=DeviceStorage(memory_config=MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::DRAM,shard_spec=std::nullopt)),shape=ttnn.Shape([1, 1, 4096, 6656]),dtype=DataType::BFLOAT4_B,layout=Layout::TILE)

for all 32 devices.

Next steps:

  • hopefully a specific program config can be used to bypass this issue - maybe the 2D one I suggested
  • @pavlepopovic's 1D test needs to be investigated
  • the 1D setup on Galaxy leading to a freeze / nothing happening should be investigated
  • the behaviour of MatmulMultiCoreNonOptimizedReuseProgramConfig should be investigated

@bbradelTT
Contributor

The following reproduces on a single WH device the hang I saw on Galaxy:

import torch
import ttnn
import time
from tracy import Profiler
device_id = 0
device = ttnn.open_device(device_id=device_id)
a=ttnn.from_torch(25*torch.ones([1,1,512,4096], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
b=ttnn.from_torch(125*torch.ones([1,1,4096,6656], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
c = ttnn.matmul(a,b, program_config=ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
                compute_with_storage_grid_size=(2, 8),
                in0_block_w=1,
                out_subblock_h=4,
                out_subblock_w=2,
                per_core_M=16,
                per_core_N=13,
                fuse_batch=False,
                fused_activation=None,
                mcast_in0=True,
            ))
ttnn.close_device(device)
exit()

@bbradelTT
Contributor

Talked with @pavlepopovic.
He just extracted the dimensions of the tensors and created a unit test for 1 chip with those dimensions, which just calls ttnn.matmul() without a specified config.

Based on this info, the code would choose MatmulMultiCoreNonOptimized: in0 is pretty narrow and fits within 16 tiles, and therefore it looks like a 1D matmul.

@bbradelTT
Contributor

@johanna-rock-tt
Are you sure the bad PCC is not just a function of using bfloat4_b with K=4096?

I replaced bfloat4_b with bfloat16 and the test passed:

pytest.param(
    512, 16 * 1024, 52 * 1024, ttnn.bfloat16, id="Llama3_405B_prefill_seq512_FF1"
),

Output:

=============================================================================== short test summary info ===============================================================================
PASSED tests/ttnn/multichip_unit_tests/test_multidevice_TG.py::test_galaxy_matmul_2d_fracture[silicon_arch_name=wormhole_b0-silicon_arch_wormhole_b0=True-Llama3_405B_prefill_seq512_FF1-8x4_grid]
============================================================================ 1 passed, 1 warning in 55.78s ============================================================================

Also, I tried bfloat4_b first and then bfloat16 with some local tests, and the results are quite different.

...
>>> a=ttnn.from_torch(25*torch.ones([1,1,512,4096], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> b=ttnn.from_torch(125*torch.ones([1,1,4096,6656], dtype=torch.bfloat16), dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=device)
>>> c=ttnn.matmul(a,b)
>>> print(c)
ttnn.Tensor([[[[130928957584814213481024995437613940736.00000, 130928957584814213481024995437613940736.00000,  ..., 130264343586921755544573091907473768448.00000, 130264343586921755544573091907473768448.00000],
               [130928957584814213481024995437613940736.00000, 130928957584814213481024995437613940736.00000,  ..., 130264343586921755544573091907473768448.00000, 130264343586921755544573091907473768448.00000],
               ...,
               [130264343586921755544573091907473768448.00000, 130264343586921755544573091907473768448.00000,  ..., 130264343586921755544573091907473768448.00000, 130264343586921755544573091907473768448.00000],
               [130264343586921755544573091907473768448.00000, 130264343586921755544573091907473768448.00000,  ..., 128935115591136839671669284847193423872.00000, 128935115591136839671669284847193423872.00000]]]], shape=Shape([1, 1, 512, 6656]), dtype=DataType::BFLOAT16, layout=Layout::TILE)
>>> b=ttnn.from_torch(125*torch.ones([1,1,4096,6656], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
>>> c=ttnn.matmul(a,b)
>>> print(c)
ttnn.Tensor([[[[15400960.00000, 15400960.00000,  ..., 15400960.00000, 15400960.00000],
               [15400960.00000, 15400960.00000,  ..., 15400960.00000, 15400960.00000],
               ...,
               [15400960.00000, 15400960.00000,  ..., 15400960.00000, 15400960.00000],
               [15400960.00000, 15400960.00000,  ..., 15400960.00000, 15400960.00000]]]], shape=Shape([1, 1, 512, 6656]), dtype=DataType::BFLOAT16, layout=Layout::TILE)

@bbradelTT
Contributor

One of the other tests that passes has a smaller K and N, roughly half of the failing test's:

pytest.param(512, 8192, 28 * 1024, ttnn.bfloat4_b, id="Llama3-70B_prefill_seq512_FF1"),

When N is 52 * 1024, K has to be very small before the test passes. I had to set K to 128 for the failing test to pass.

@bbradelTT
Contributor

For the 1D matmul, out_subblock_w needs to be 1.
I thought validate already checked for that. I'll need to see why it didn't here, but that part is working as expected.

@bbradelTT
Contributor

validate did check for what I expected, but there was a bug in the code where validation was disabled when the program cache is off. This bug was fixed yesterday.

I verified with the following that this is a bfloat4_b precision issue.

    in0=1*torch.eye(1024)
    in1=1*torch.ones([1,1,1024,6656])
    a=ttnn.from_torch(in0, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat4_b)
    b=ttnn.from_torch(in1, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat4_b, device=device)
    c = ttnn.matmul(a,b)
    print(c)
    d = torch.matmul(in0,in1)
    print(d)

This uses MatmulMultiCoreNonOptimizedReuseProgramConfig, and both c and d are all 1s.

I also verified that MatmulMultiCoreProgramConfig is okay with the following:

    in0=1*torch.eye(4096)
    in1=1*torch.ones([1,1,4096,6656])
    a=ttnn.from_torch(in0, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat4_b)
    b=ttnn.from_torch(in1, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat4_b, device=device)
    c = ttnn.matmul(a,b)
    print(c)
    d = torch.matmul(in0,in1)
    print(d)

@johanna-rock-tt you'll need to use the program config or avoid this specific shape + precision combination.

@johanna-rock-tt
Contributor Author

Thanks @bbradelTT
I verified that your proposed program config works!

However, I'm still puzzled by bfp4 with this specific shape producing ~0.0 PCC while similar shapes in bfp4 (e.g. M = 256, K = 16 * 1024, N = 52 * 1024) produce 0.99 PCC.

@bbradelTT
Contributor

@johanna-rock-tt
I think once the tensors get big enough there is overflow and rounding happening, where the order of operations has a big impact.

bfloat4_b and bfloat8_b tensors have some info that is common across values. Based on what I saw, probably at least for the mantissa.

E.g. I would get sets of values such as 0.50000, 0.25000, ..., 0.12500, 0.25000 or 512.00000, 384.00000, ..., 128.00000, 896.00000, with nothing between set boundaries.

If that common info is set based on the inputs, or on some criteria that can't handle the right ranges, and there is enough addition, the values probably drift off.

It'd be good to have a reference for bfloat4_b, and then we could do more than speculate.
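
A rough illustration of the shared-exponent effect described above; this is a simplified sketch, not the actual bfloat4_b implementation:

import torch

def quantize_shared_exponent(block: torch.Tensor, mantissa_bits: int = 3) -> torch.Tensor:
    # One exponent is shared by the whole block, chosen from its largest
    # magnitude; each value then keeps only a few mantissa bits.
    max_mag = block.abs().max()
    if max_mag == 0:
        return torch.zeros_like(block)
    step = 2.0 ** (torch.floor(torch.log2(max_mag)) - mantissa_bits + 1)
    return torch.round(block / step) * step

x = torch.tensor([0.51, 0.26, 0.13, 0.07, 0.031, 0.9])
print(quantize_shared_exponent(x))  # tensor([0.5000, 0.2500, 0.1250, 0.1250, 0.0000, 0.8750])
# Values snap to a coarse grid (multiples of 0.125 here), consistent with the
# 0.5 / 0.25 / 0.125-style sets observed above; small values in a block that
# also contains large ones lose nearly all precision, so long accumulations drift.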

@johanna-rock-tt
Contributor Author

I see, thanks for the explanation!

@uaydonat
Contributor

Bad PCC could be explained by overflows, but do we have an explanation for the ND PCC that was initially reported?

@bbradelTT
Contributor

@uaydonat I don't see a fixed seed in the test. Wouldn't we always get slightly different PCCs each time in that case?

@uaydonat
Contributor

> @uaydonat I don't see a fixed seed in the test. Wouldn't we always get slightly different PCCs each time in that case?

Yes, we would.

@johanna-rock-tt please check whether the test you mentioned in your top comment sets the seed. If not, please check whether the ND goes away with a fixed seed.
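
(A minimal illustration of the point about seeding: without a fixed seed, torch generates different random inputs on every run, so the measured PCC can vary even for a deterministic op.)

import torch

torch.manual_seed(1234)      # fix the RNG so the random test inputs repeat
a = torch.randn(512, 4096)   # identical across runs that set the same seed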

@johanna-rock-tt
Contributor Author

Just checked. We didn't have a manual seed set, but setting the seed (torch.manual_seed(1234)) still results in ND PCC for me:

PCC value: Max ATOL Delta: nan, Max RTOL Delta: nan, PCC: -0.00031071428310952017, PCC check failed
PCC value: Max ATOL Delta: nan, Max RTOL Delta: nan, PCC: -0.00024354556162184027, PCC check failed
PCC value: Max ATOL Delta: nan, Max RTOL Delta: nan, PCC: -0.00013977039323418247, PCC check failed

@bbradelTT
Contributor

@johanna-rock-tt for the random PCCs:

  1. What about the other test cases?
  2. What if you use the program config that gives a good PCC?

@bbradelTT
Contributor

I'm wondering if the remaining ND PCC behaviour is related to #10673

@johanna-rock-tt
Contributor Author

The other shapes in the same test, as well as the problematic shape with the program config that gives good PCC, produce deterministic PCC (tested with 3 runs each).

@bbradelTT
Contributor

Interesting. The kernels used by MatmulMultiCoreNonOptimized may not be configured properly to handle the bfloat4_b behaviour that shows up when rounding is involved.

I'll have to look into that.

@bbradelTT
Contributor

I'm trying with a sample test:

a=ttnn.from_torch(1*torch.ones([1,1,512,4096], dtype=torch.bfloat16), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
b=ttnn.from_torch(1*torch.ones([1,1,4096,6656], dtype=torch.bfloat16), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) # result in incorrect output
c = ttnn.matmul(a,b)
print(c)

where the dtype=ttnn.bfloat... values change. The output should be all 4096s.
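
For reference, a minimal torch-only check of the expected value (a sketch; no device needed): with all-ones inputs and inner dimension K = 4096, every output element is a sum of 4096 ones.

import torch

a = torch.ones(1, 1, 512, 4096)
b = torch.ones(1, 1, 4096, 6656)
c = torch.matmul(a, b)          # float32 reference, no block-float rounding
assert torch.all(c == 4096.0)   # every element should be exactly 4096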

in0 dtype >= in1 dtype:

  • bfloat16 and bfloat16 - correct results
  • bfloat16 and bfloat8_b - incorrect results
  • bfloat16 and bfloat4_b - incorrect results
  • bfloat8_b and bfloat8_b - correct results
  • bfloat8_b and bfloat4_b - incorrect results
  • bfloat4_b and bfloat4_b - incorrect results

in0 dtype < in1 dtype:

  • bfloat8_b and bfloat16 - incorrect results
  • bfloat4_b and bfloat16 - incorrect results
  • bfloat4_b and bfloat8_b - incorrect results

Compute kernel with the issue:
ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm.cpp

Compute kernel used by 2d mcast that does not have the issue:
ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp

I'll try to figure out what the difference is that could be causing the different behaviour.

@bbradelTT
Contributor

I talked to @tt-aho.
It turns out that mixed precision was not supported by the kernel.
@tt-aho created a commit a184269 that addressed this issue, and I created PR #11947 to run all the pipelines, etc.

I verified that after this change the test passes, and that with torch.manual_seed(1234) running it a couple of times produces the same PCC:

2024-08-27 14:55:05.818 | INFO     | tests.ttnn.multichip_unit_tests.test_multidevice_TG:test_galaxy_matmul_2d_fracture:159 - PCC value: Max ATOL Delta: 105.75125122070312, Max RTOL Delta: inf, PCC: 0.9926411554984379
2024-08-27 14:58:02.574 | INFO     | tests.ttnn.multichip_unit_tests.test_multidevice_TG:test_galaxy_matmul_2d_fracture:159 - PCC value: Max ATOL Delta: 105.75125122070312, Max RTOL Delta: inf, PCC: 0.9926411554984379

@johanna-rock-tt
Contributor Author

That's great! Thanks for investigating and fixing!

bbradelTT added a commit that referenced this issue Aug 27, 2024
@bbradelTT
Contributor

@johanna-rock-tt the fix is merged. Please verify that everything works as expected and then you can uncomment the relevant test parameter combination.
