diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py index b79f96d2..46d67aec 100644 --- a/RWKV-v4neo/src/model.py +++ b/RWKV-v4neo/src/model.py @@ -53,9 +53,9 @@ def forward(ctx, B, T, C, w, u, k, v): assert T <= T_MAX assert B * C % min(C, 32) == 0 w = -torch.exp(w.float().contiguous()) - u = u.contiguous() - k = k.contiguous() - v = v.contiguous() + u = u.bfloat16().contiguous() + k = k.bfloat16().contiguous() + v = v.bfloat16().contiguous() y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) @@ -72,7 +72,7 @@ def backward(ctx, gy): gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) + wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.bfloat16().contiguous(), gw, gu, gk, gv) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) return (None, None, None, gw, gu, gk, gv)