Skip to content

Commit

Permalink
refactor code and change validation logic to use new annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
ibrokethecloud committed Sep 13, 2024
1 parent b6c8e04 commit a7fb75f
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 45 deletions.
10 changes: 10 additions & 0 deletions pkg/apis/devices.harvesterhci.io/v1beta1/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package v1beta1

const (
DeviceAllocationKey = "harvesterhci.io/deviceAllocationDetails"
)

type AllocationDetails struct {
GPUs map[string][]string `json:"gpus,omitempty"`
HostDevices map[string][]string `json:"hostdevices,omitempty"`
}
20 changes: 7 additions & 13 deletions pkg/controller/virtualmachineinstance/virtualmachineinstance.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ import (
)

const (
kubevirtVMLabelKey = "vm.kubevirt.io/name"
DeviceAllocationKey = "harvesterhci.io/deviceAllocationDetails"
kubevirtVMLabelKey = "vm.kubevirt.io/name"
)

type Handler struct {
Expand All @@ -45,11 +44,6 @@ type Handler struct {
nodeName string
}

type AllocationDetails struct {
GPUs map[string][]string `json:"gpus,omitempty"`
HostDevices map[string][]string `json:"hostdevices,omitempty"`
}

func Register(ctx context.Context, management *config.FactoryManager) error {
vmCache := management.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()
vmClient := management.KubevirtFactory.Kubevirt().V1().VirtualMachine()
Expand Down Expand Up @@ -253,12 +247,12 @@ func (h *Handler) checkAndClearDeviceAllocation(vmi *kubevirtv1.VirtualMachineIn
return nil
}

_, ok := vmObj.Annotations[DeviceAllocationKey]
_, ok := vmObj.Annotations[v1beta1.DeviceAllocationKey]
// no key, nothing is needed
if !ok {
return nil
}
delete(vmObj.Annotations, DeviceAllocationKey)
delete(vmObj.Annotations, v1beta1.DeviceAllocationKey)
_, err = h.vmClient.Update(vmObj)
return err
}
Expand All @@ -281,8 +275,8 @@ func buildPCIDeviceMap(pciDevices []*v1beta1.PCIDevice) map[string]string {
return result
}

func generateAllocationDetails(hostDeviceMap, gpuMap map[string][]string) *AllocationDetails {
resp := &AllocationDetails{}
func generateAllocationDetails(hostDeviceMap, gpuMap map[string][]string) *v1beta1.AllocationDetails {
resp := &v1beta1.AllocationDetails{}
if len(hostDeviceMap) > 0 {
hostDeviceMap = dedupDevices(hostDeviceMap)
resp.HostDevices = hostDeviceMap
Expand Down Expand Up @@ -357,11 +351,11 @@ func (h *Handler) reconcileVMResourceAllocationAnnotation(vmi *kubevirtv1.Virtua
if vmObj.Annotations == nil {
vmObj.Annotations = make(map[string]string)
} else {
currentAnnotationValue = vmObj.Annotations[DeviceAllocationKey]
currentAnnotationValue = vmObj.Annotations[v1beta1.DeviceAllocationKey]
}

if currentAnnotationValue != deviceDetails {
vmObj.Annotations[DeviceAllocationKey] = deviceDetails
vmObj.Annotations[v1beta1.DeviceAllocationKey] = deviceDetails
_, err = h.vmClient.Update(vmObj)
}
return err
Expand Down
81 changes: 81 additions & 0 deletions pkg/util/common/common.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package common

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"

kubevirtv1 "kubevirt.io/api/core/v1"

"github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
)

const (
Expand Down Expand Up @@ -77,3 +82,79 @@ func GetVFList(pfDir string) (vfList []string, err error) {
}
return
}

// VMByHostDeviceName indexes VM's by host device name.
// It could be usb device claim or pci device claim name.
func VMByHostDeviceName(obj *kubevirtv1.VirtualMachine) ([]string, error) {
if obj.Annotations == nil {
return nil, nil
}

allocationDetails, ok := obj.Annotations[v1beta1.DeviceAllocationKey]
if !ok {
return nil, nil
}

allocatedHostDevices, err := generateHostDeviceAllocation(obj, allocationDetails)
if err != nil {
return nil, err
}

return allocatedHostDevices, nil
}

// VMByVGPUDevice indexes VM's by vgpu names
func VMByVGPUDevice(obj *kubevirtv1.VirtualMachine) ([]string, error) {
// find and add vgpu info from the DeviceAllocationKey annotation if present on the vm
if obj.Annotations == nil {
return nil, nil
}
allocationDetails, ok := obj.Annotations[v1beta1.DeviceAllocationKey]
if !ok {
return nil, nil
}

allocatedGPUs, err := generateGPUDeviceAllocation(obj, allocationDetails)
if err != nil {
return nil, err
}
return allocatedGPUs, nil
}

func generateDeviceAllocationDetails(allocationDetails string) (*v1beta1.AllocationDetails, error) {
currentAllocation := &v1beta1.AllocationDetails{}
err := json.Unmarshal([]byte(allocationDetails), currentAllocation)
return currentAllocation, err
}

func generateDeviceInfo(devices map[string][]string) []string {
var allDevices []string
for _, v := range devices {
allDevices = append(allDevices, v...)
}
return allDevices
}

func generateGPUDeviceAllocation(obj *kubevirtv1.VirtualMachine, allocationDetails string) ([]string, error) {
allocation, err := generateDeviceAllocationDetails(allocationDetails)
if err != nil {
return nil, fmt.Errorf("error generating device allocation details %s/%s: %v", obj.Name, obj.Namespace, err)
}

if allocation.GPUs != nil {
return generateDeviceInfo(allocation.GPUs), nil
}
return nil, nil
}

func generateHostDeviceAllocation(obj *kubevirtv1.VirtualMachine, allocationDetails string) ([]string, error) {
allocation, err := generateDeviceAllocationDetails(allocationDetails)
if err != nil {
return nil, fmt.Errorf("error generating device allocation details %s/%s: %v", obj.Name, obj.Namespace, err)
}

if allocation.HostDevices != nil {
return generateDeviceInfo(allocation.HostDevices), nil
}
return nil, nil
}
23 changes: 15 additions & 8 deletions pkg/util/fakeclients/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fakeclients

import (
"context"
"slices"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
Expand All @@ -11,6 +12,8 @@ import (

kubevirtv1 "github.com/harvester/harvester/pkg/generated/clientset/versioned/typed/kubevirt.io/v1"
kubevirtctlv1 "github.com/harvester/harvester/pkg/generated/controllers/kubevirt.io/v1"

"github.com/harvester/pcidevices/pkg/util/common"
)

const (
Expand Down Expand Up @@ -88,10 +91,12 @@ func (c VirtualMachineCache) GetByIndex(indexName, key string) ([]*kubevirtv1api
}

for _, vm := range vmList {
for _, hostDevice := range vm.Spec.Template.Spec.Domain.Devices.HostDevices {
if hostDevice.Name == key {
vms = append(vms, vm)
}
deviceInfo, err := common.VMByHostDeviceName(vm)
if err != nil {
return nil, err
}
if slices.Contains(deviceInfo, key) {
vms = append(vms, vm)
}
}
return vms, nil
Expand All @@ -103,10 +108,12 @@ func (c VirtualMachineCache) GetByIndex(indexName, key string) ([]*kubevirtv1api
}

for _, vm := range vmList {
for _, gpuDevice := range vm.Spec.Template.Spec.Domain.Devices.GPUs {
if gpuDevice.Name == key {
vms = append(vms, vm)
}
deviceInfo, err := common.VMByVGPUDevice(vm)
if err != nil {
return nil, err
}
if slices.Contains(deviceInfo, key) {
vms = append(vms, vm)
}
}
return vms, nil
Expand Down
26 changes: 4 additions & 22 deletions pkg/webhook/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
kubevirtv1 "kubevirt.io/api/core/v1"

"github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
"github.com/harvester/pcidevices/pkg/util/common"
)

const (
Expand All @@ -21,9 +22,9 @@ const (
func RegisterIndexers(clients *Clients) {
vmCache := clients.KubevirtFactory.Kubevirt().V1().VirtualMachine().Cache()
vmCache.AddIndexer(VMByName, vmByName)
vmCache.AddIndexer(VMByPCIDeviceClaim, vmByHostDeviceName)
vmCache.AddIndexer(VMByUSBDeviceClaim, vmByHostDeviceName)
vmCache.AddIndexer(VMByVGPU, vmByVGPUDevice)
vmCache.AddIndexer(VMByPCIDeviceClaim, common.VMByHostDeviceName)
vmCache.AddIndexer(VMByUSBDeviceClaim, common.VMByHostDeviceName)
vmCache.AddIndexer(VMByVGPU, common.VMByVGPUDevice)
deviceCache := clients.DeviceFactory.Devices().V1beta1().PCIDevice().Cache()
deviceCache.AddIndexer(PCIDeviceByResourceName, pciDeviceByResourceName)
deviceCache.AddIndexer(IommuGroupByNode, iommuGroupByNodeName)
Expand All @@ -45,25 +46,6 @@ func iommuGroupByNodeName(obj *v1beta1.PCIDevice) ([]string, error) {
return []string{fmt.Sprintf("%s-%s", obj.Status.NodeName, obj.Status.IOMMUGroup)}, nil
}

// vmByHostDeviceName indexes VM's by host device name.
// It could be usb device claim or pci device claim name.
func vmByHostDeviceName(obj *kubevirtv1.VirtualMachine) ([]string, error) {
hostDeviceName := make([]string, 0, len(obj.Spec.Template.Spec.Domain.Devices.HostDevices))
for _, hostDevice := range obj.Spec.Template.Spec.Domain.Devices.HostDevices {
hostDeviceName = append(hostDeviceName, hostDevice.Name)
}
return hostDeviceName, nil
}

// vmByVGPUDevice indexes VM's by vgpu names
func vmByVGPUDevice(obj *kubevirtv1.VirtualMachine) ([]string, error) {
gpuNames := make([]string, 0, len(obj.Spec.Template.Spec.Domain.Devices.GPUs))
for _, gpuDevice := range obj.Spec.Template.Spec.Domain.Devices.GPUs {
gpuNames = append(gpuNames, gpuDevice.Name)
}
return gpuNames, nil
}

func usbDeviceClaimByAddress(obj *v1beta1.USBDeviceClaim) ([]string, error) {
return []string{fmt.Sprintf("%s-%s", obj.Status.NodeName, obj.Status.PCIAddress)}, nil
}
3 changes: 3 additions & 0 deletions pkg/webhook/usbdeviceclaim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ var (
ObjectMeta: metav1.ObjectMeta{
Name: "vm-with-usb-devices",
Namespace: "default",
Annotations: map[string]string{
devicesv1beta1.DeviceAllocationKey: `{"hostdevices":{"fake.com/device1":["usbdevice1"]}}`,
},
},
Spec: kubevirtv1.VirtualMachineSpec{
Template: &kubevirtv1.VirtualMachineInstanceTemplateSpec{
Expand Down
5 changes: 3 additions & 2 deletions pkg/webhook/vgpu.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (

"github.com/sirupsen/logrus"

kubevirtctl "github.com/harvester/harvester/pkg/generated/controllers/kubevirt.io/v1"
"github.com/harvester/harvester/pkg/webhook/types"
admissionregv1 "k8s.io/api/admissionregistration/v1"
"k8s.io/apimachinery/pkg/runtime"

kubevirtctl "github.com/harvester/harvester/pkg/generated/controllers/kubevirt.io/v1"
"github.com/harvester/harvester/pkg/webhook/types"

devicesv1beta1 "github.com/harvester/pcidevices/pkg/apis/devices.harvesterhci.io/v1beta1"
)

Expand Down
4 changes: 4 additions & 0 deletions pkg/webhook/vgpu_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ var (
ObjectMeta: metav1.ObjectMeta{
Name: "vgpu-vm",
Namespace: "default",
Annotations: map[string]string{
devicesv1beta1.DeviceAllocationKey: `{"gpus":{"nvidia.com/fakevgpu":["vgpu1"]}}`,
},
},
Spec: kubevirtv1.VirtualMachineSpec{
Template: &kubevirtv1.VirtualMachineInstanceTemplateSpec{
Expand Down Expand Up @@ -173,6 +176,7 @@ func Test_VGPUDeletion(t *testing.T) {
vGPUValidator := NewVGPUValidator(virtualMachineCache)
for _, v := range testCases {
err := vGPUValidator.Delete(nil, v.gpu)
t.Log(err)
if v.expectError {
assert.Error(err, fmt.Sprintf("expected to find error for test case %s", v.name))
} else {
Expand Down
6 changes: 6 additions & 0 deletions pkg/webhook/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ var (
ObjectMeta: metav1.ObjectMeta{
Name: "vm-with-iommu-devices",
Namespace: "default",
Annotations: map[string]string{
devicesv1beta1.DeviceAllocationKey: `{"hostdevices":{"fake.com/device1":["node1dev1"]}}`,
},
},
Spec: kubevirtv1.VirtualMachineSpec{
Template: &kubevirtv1.VirtualMachineInstanceTemplateSpec{
Expand All @@ -130,6 +133,9 @@ var (
ObjectMeta: metav1.ObjectMeta{
Name: "vm-with-iommu-devices",
Namespace: "default",
Annotations: map[string]string{
devicesv1beta1.DeviceAllocationKey: `{"hostdevices":{"fake.com/device1":["node1dev1"],"fake.com/device2":["node1dev2"]}}`,
},
},
Spec: kubevirtv1.VirtualMachineSpec{
Template: &kubevirtv1.VirtualMachineInstanceTemplateSpec{
Expand Down

0 comments on commit a7fb75f

Please sign in to comment.