forked from peterwauligmann/PSpaMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pspamm.py
executable file
·67 lines (42 loc) · 2.24 KB
/
pspamm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
import argparse
import architecture
from matmul import *
from codegen.ccode import *
from codegen.architectures import *
# Values accepted by the --mtx_format command-line option ("Constraint on
# sparsity pattern"); passed to argparse as its `choices`.
mtx_formats = [
    'any',
    'csc',
    'csr',
    'bsc',
    'bsr',
    'bcsc',
    'bcsr',
]
def main(alg: MatMul) -> None:
    """Build the kernel described by *alg* and emit its generated source.

    ``alg.make()`` produces the kernel body, which ``make_cfunc`` wraps
    together with the generator's template, the flop count, the starting
    registers and the precision.

    If ``alg.output_filename`` is ``None`` the generated text is printed to
    stdout. Otherwise it is written to that file. The file is opened in
    ``"w"`` mode (truncate) when ``alg.output_overwrite`` is truthy, and in
    ``"a"`` mode (append) when it is not.
    """
    kernel_block = alg.make()
    generated = make_cfunc(
        alg.output_funcname,
        alg.generator.get_template(),
        kernel_block,
        alg.flop,
        alg.starting_regs,
        alg.generator.get_precision(),
    )

    if alg.output_filename is None:
        print(generated)
        return

    # Append unless the caller explicitly asked to overwrite the file.
    open_mode = "w" if alg.output_overwrite else "a"
    with open(alg.output_filename, open_mode) as out_file:
        out_file.write(generated)
if __name__ == "__main__":
    # Command-line entry point: parse the problem description and options,
    # forward every parsed option to MatMul as a keyword argument (so each
    # dest name below must be accepted by MatMul), then generate the code.
    parser = argparse.ArgumentParser(description='Generate a sparse matrix multiplication algorithm for C = alpha * A * B + beta * C.')

    # Positional problem description: sizes, leading dimensions and scalars.
    parser.add_argument("m", type=int, help="Number of rows of A and C")
    parser.add_argument("n", type=int, help="Number of cols of B and C")
    parser.add_argument("k", type=int, help="Number of cols of A, rows of B")
    parser.add_argument("lda", type=int, help="Leading dimension of A (zero if A is sparse)")
    parser.add_argument("ldb", type=int, help="Leading dimension of B (zero if B is sparse)")
    parser.add_argument("ldc", type=int, help="Leading dimension of C")
    parser.add_argument("alpha", type=str, help="alpha, 1.0 or generic")
    parser.add_argument("beta", type=str, help="beta, 1.0, 0.0, or generic")

    # Optional blocking sizes (default None when omitted).
    parser.add_argument("--bm", type=int, help="Size of m-blocks")
    parser.add_argument("--bn", type=int, help="Size of n-blocks")
    parser.add_argument("--bk", type=int, help="Size of k-blocks")

    # Target and code-generation options.
    parser.add_argument("--arch", help="Architecture", default="knl")
    parser.add_argument("--precision", help="Single (s) or double (d) precision", default="d")
    parser.add_argument("--prefetching", help="Prefetching")

    # Sparsity description. argparse does not check defaults against
    # `choices`, so the default must be spelled exactly like an entry of
    # mtx_formats (lowercase 'any'). Otherwise the implicit default could
    # never equal a value a user is allowed to pass explicitly.
    parser.add_argument("--mtx_filename", help="Path to MTX file describing the sparse matrix")
    parser.add_argument("--mtx_format", help="Constraint on sparsity pattern", choices=mtx_formats, default="any")

    # Output destination; without --output_filename the code goes to stdout.
    parser.add_argument("--output_funcname", help="Name for generated C++ function")
    parser.add_argument("--output_filename", help="Path to destination C++ file")
    parser.add_argument("--output_overwrite", action="store_true", help="Overwrite output file")

    args = parser.parse_args()
    alg = MatMul(**vars(args))
    main(alg)