
[TorchToTosa] add conv reshape in core lowering #4494

Closed

catcor01 wants to merge 7 commits into llvm:main from catcor01:conv_reshape

Conversation

@catcor01 (Contributor) commented Mar 9, 2026

This change improves the Torch-to-TOSA convolution lowering for cases where the
convolution operands do not already have the rank that TOSA requires.

The lowering now supports two cases:

  • reusing a direct local reshape template on the operand’s own value chain
  • inferring the required 4D/5D input or weight shape from the ranked operands
    and constant convolution attributes when no local template exists
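As an illustrative sketch of the second case (SSA names, shapes, and attribute values below are hypothetical, not taken from the patch), the inferred-shape path amounts to materializing the required 4D shape and inserting a tosa.reshape in front of the convolution:

```mlir
// Hypothetical example: a rank-1 weight value is reshaped to the rank-4
// layout TOSA's convolution expects. The target shape is inferred from the
// ranked operands and the constant convolution attributes rather than
// copied from a sibling reshape user.
%shape = tosa.const_shape {values = dense<[1, 1, 16, 16]> : tensor<4xindex>} : () -> !tosa.shape<4>
%weight4d = tosa.reshape %weight_flat, %shape : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32>
```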

Bias normalization is handled separately and kept TOSA-correct by requiring a
1D bias, or reshaping to 1D when the element count matches the output channel
count.
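A minimal sketch of the bias path (values and names are hypothetical): when the bias is not already 1D but its element count equals the output channel count, it is flattened to the rank-1 form TOSA requires:

```mlir
// Hypothetical example: a rank-3 bias with 8 elements matching the output
// channel count is reshaped down to rank 1.
%bias_shape = tosa.const_shape {values = dense<[8]> : tensor<1xindex>} : () -> !tosa.shape<1>
%bias1d = tosa.reshape %bias, %bias_shape : (tensor<1x1x8xf32>, !tosa.shape<1>) -> tensor<8xf32>
```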

This is intentionally an internal legalization improvement:

  • it does not change the func.func boundary for real model inputs
  • it does not depend on sibling reshape users for legalization

The tests now include positive before/after IR coverage for direct local reshape
reuse, inferred input normalization, inferred weight normalization, and bias
normalization, while keeping negative coverage for unsupported cases such as
dynamic element counts.

Change-Id: Ica1b5cc265822ecd054f832908ec31bc2325c661

@catcor01 (Contributor, Author) commented Mar 9, 2026

@sahas3 @Lallapallooza

- Insert rank-4/5 reshapes for conv inputs/weights during TorchToTosa lowering

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: Ica1b5cc265822ecd054f832908ec31bc2325c661
Change-Id: I5c0c1a5ae2d90cee500dc76247f5952b99bb48f9
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I55752f920f8ad170b3d35c0c8bf5f8b94c4d9de0
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I49732cecaeb7b2ffd3a0e6bf4e74cb0d16aa5e48
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
…e TOSA lowering

Change-Id: Ie4c7767b2155cfa4c81652e60d9698512285de0a
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I9981da1397bbc6221f3f4a6f7c1e1c0f7991bb4b
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: If5a019bc78b81ca7f164f6416f621c0ee1582294
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
@catcor01 catcor01 requested a review from sahas3 March 25, 2026 08:28
np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False
),
onnx.TensorProto.DataType.BOOL: lambda tp: DenseElementsAttr.get(
np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims),
Member:
Is this change intended? It'll be good to separate this change into a different PR.

%weight_template = tosa.reshape %weight_builtin, %shape_builtin : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32>
%input_flat = builtin.unrealized_conversion_cast %input_template : tensor<1x1x16x16xf32> to !torch.vtensor<[256],f32>
%weight_flat = builtin.unrealized_conversion_cast %weight_template : tensor<1x1x16x16xf32> to !torch.vtensor<[256],f32>
%conv = torch.aten.convolution %input_flat, %weight_flat, %arg2, %stride, %padding, %dilation, %false, %output_padding, %int1 : !torch.vtensor<[256],f32>, !torch.vtensor<[256],f32>, !torch.vtensor<[1],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,1,1,1],f32>
Member:

Thanks, @catcor01 for adding these positive testcases. I now understand fully what the proposed change is trying to accomplish.

However, before proceeding further I'd like to understand how such torch IR is being generated -- looking at https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv2d.html it seems that conv only accepts 4D/3D tensors as inputs/weights. Is there an upstream pass that flattens the shape like shown here and is that the bug we should try to resolve? Thanks!

@catcor01 catcor01 closed this Apr 13, 2026
@catcor01 (Contributor, Author):

Closing as not required based on comment #4494 (comment).
