Fix for swin_transformer(matmul+transpose+reshape) #35740
Thanks for your contribution!
LGTM
@piotrekobiIntel @tsocha Could you please review, and advise me whether the operation of inferring the missing dim can be done more concisely?
LGTM
I would suggest using a simple for loop to check whether "-1" is present. That way you don't need to iterate over the array twice, once to detect its presence and again to compute the distance. I don't know how often "-1" will be present, but if it is often enough, it might even be better to calculate the product in that same for loop instead of iterating over the array once again.
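The single-pass idea above could look roughly like the sketch below. `ShapeScan` and `ScanReshapeDims` are hypothetical names used only for illustration, not part of the PR or of Paddle, and the sketch assumes `reshape_out` contains at most one "-1" entry.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper type (not from the PR): result of one scan over the shape.
struct ShapeScan {
  int missing_index;           // position of the "-1" entry, or -1 if none was found
  std::int64_t known_product;  // product of all entries other than "-1"
};

// Finds the "-1" entry and multiplies the known dims in a single loop.
// Assumption: at most one "-1" is present; if there were several,
// only the last position would be remembered.
ShapeScan ScanReshapeDims(const std::vector<int>& reshape_out) {
  ShapeScan scan{-1, 1};
  for (std::size_t i = 0; i < reshape_out.size(); ++i) {
    if (reshape_out[i] == -1) {
      scan.missing_index = static_cast<int>(i);
    } else {
      scan.known_product *= reshape_out[i];
    }
  }
  return scan;
}
```

With this, the caller only needs the product of the output dims to fill in `reshape_out[scan.missing_index]` when `scan.missing_index` is not -1.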
// if "-1" is present then one of the reshape dims must be inferred
if (it != reshape_out.end()) {
  int index = std::distance(reshape_out.begin(), it);

  auto ddim_out_vec = framework::vectorize(ddim_out);

  int ddim_out_product =
      std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
                      std::multiplies<int>());
  int reshape_out_product = std::accumulate(
      reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());

  reshape_out[index] = ddim_out_product / reshape_out_product;
}
You modify a reshape_out dim that the iterator already points to.
There is no need to calculate an index via std::distance; just assign through 'it'.
I wonder whether it is worth checking that the shapes are compatible:
'''
ddim_out_product % reshape_out_product == 0
'''
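Both suggestions together might look like the sketch below. `InferMissingDim` is a hypothetical standalone function written for illustration; it takes plain vectors instead of Paddle's `DDim`, and returning `false` on incompatible shapes is an assumption about how the error could be reported, not what the PR does.

```cpp
#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper (not from the PR): fills in the single "-1" entry of
// reshape_out so that its element count matches ddim_out_vec.
// Returns false and leaves reshape_out untouched if the shapes are incompatible.
bool InferMissingDim(std::vector<int>& reshape_out,
                     const std::vector<int>& ddim_out_vec) {
  auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
  if (it == reshape_out.end()) {
    return true;  // nothing to infer
  }

  int ddim_out_product = std::accumulate(
      ddim_out_vec.begin(), ddim_out_vec.end(), 1, std::multiplies<int>());
  // Starting from -1 cancels the single "-1" entry, as in the PR code.
  int reshape_out_product = std::accumulate(
      reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());

  // Compatibility check: the known dims must divide the total element count.
  if (reshape_out_product == 0 ||
      ddim_out_product % reshape_out_product != 0) {
    return false;
  }

  *it = ddim_out_product / reshape_out_product;  // write through the iterator
  return true;
}
```

For example, inferring `{2, -1, 4}` against an output of shape `{2, 3, 4}` sets the missing dim to 3, while `{5, -1}` against `{2, 3}` is rejected because 6 is not divisible by 5.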
PR types
Bug fixes
PR changes
OPs
Describe
Fix for #35719. The problem occurred because matmul+transpose+reshape did not handle the dim inference that the reshape op uses. A second problem was that 6-dim tensors were not handled; they now are.