Skip to content

Commit

Permalink
use context timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
gamoutatsumi committed Apr 16, 2024
1 parent c2a3e22 commit e214990
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 25 deletions.
73 changes: 52 additions & 21 deletions server/pkg/api/pool.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package api

import (
"context"
"errors"
"fmt"
"log"
"log/slog"
"math/rand"
"sort"
"strconv"
Expand All @@ -17,19 +19,36 @@ import (
"github.com/whywaita/shoes-lxd-multi/server/pkg/metric"
)

func getInstancesWithTimeout(h lxdclient.LXDHost, d time.Duration) (s []api.Instance, overCommitPercent uint64, err error) {
done := make(chan struct{})
type gotInstances struct {
Instances []api.Instance
OverCommitPercent uint64
Error error
}

func getInstancesWithTimeout(_ctx context.Context, h lxdclient.LXDHost, d time.Duration, l *slog.Logger) ([]api.Instance, uint64, error) {
ret := make(chan *gotInstances)
ctx, cancel := context.WithTimeout(_ctx, d)
defer cancel()
go func() {
defer close(done)
s, err = lxdclient.GetAnyInstances(h.Client)
defer close(ret)
s, err := lxdclient.GetAnyInstances(h.Client)
if err != nil {
ret <- &gotInstances{
Instances: nil,
OverCommitPercent: 0,
Error: fmt.Errorf("failed to get instances: %w", err),
}
return
}
r, err := lxdclient.GetResource(h.HostConfig)
r, err := lxdclient.GetResource(ctx, h.HostConfig, l)
if err != nil {
ret <- &gotInstances{
Instances: nil,
OverCommitPercent: 0,
Error: fmt.Errorf("failed to get resource: %w", err),
}
return
}

var used uint64
for _, i := range s {
if i.StatusCode != api.Running {
Expand All @@ -41,18 +60,30 @@ func getInstancesWithTimeout(h lxdclient.LXDHost, d time.Duration) (s []api.Inst
}
cpu, err := strconv.Atoi(i.Config["limits.cpu"])
if err != nil {
err = fmt.Errorf("failed to parse limits.cpu: %w", err)
ret <- &gotInstances{
Instances: nil,
OverCommitPercent: 0,
Error: fmt.Errorf("failed to parse limits.cpu: %w", err),
}
return
}
used += uint64(cpu)
}
overCommitPercent = uint64(float64(used) / float64(r.CPUTotal) * 100)
overCommitPercent := uint64(float64(used) / float64(r.CPUTotal) * 100)
ret <- &gotInstances{
Instances: s,
OverCommitPercent: overCommitPercent,
Error: nil,
}
}()
select {
case <-done:
return
case <-time.After(d):
return nil, 0, errors.New("timed out")

for {
select {
case <-ctx.Done():
return nil, 0, errors.New("timed out")
case r := <-ret:
return r.Instances, r.OverCommitPercent, r.Error
}
}
}

Expand All @@ -61,7 +92,7 @@ type instance struct {
InstanceName string
}

func findInstances(targets []lxdclient.LXDHost, match func(api.Instance) bool, limitOverCommit uint64) []instance {
func findInstances(ctx context.Context, targets []lxdclient.LXDHost, match func(api.Instance) bool, limitOverCommit uint64, l *slog.Logger) []instance {
type result struct {
host *lxdclient.LXDHost
overCommitPercent uint64
Expand All @@ -75,7 +106,7 @@ func findInstances(targets []lxdclient.LXDHost, match func(api.Instance) bool, l
go func(i int, target lxdclient.LXDHost) {
defer wg.Done()

s, overCommitPercent, err := getInstancesWithTimeout(target, 10*time.Second)
s, overCommitPercent, err := getInstancesWithTimeout(ctx, target, 10*time.Second, l)
if err != nil {
log.Printf("failed to find instance in host %q: %+v", target.HostConfig.LxdHost, err)
return
Expand Down Expand Up @@ -123,18 +154,18 @@ func findInstances(targets []lxdclient.LXDHost, match func(api.Instance) bool, l
return instances
}

func findInstanceByJob(targets []lxdclient.LXDHost, runnerName string) (*lxdclient.LXDHost, string, bool) {
s := findInstances(targets, func(i api.Instance) bool {
func findInstanceByJob(ctx context.Context, targets []lxdclient.LXDHost, runnerName string, l *slog.Logger) (*lxdclient.LXDHost, string, bool) {
s := findInstances(ctx, targets, func(i api.Instance) bool {
return i.Config[lxdclient.ConfigKeyRunnerName] == runnerName
}, 0)
}, 0, l)
if len(s) < 1 {
return nil, "", false
}
return s[0].Host, s[0].InstanceName, true
}

func allocatePooledInstance(targets []lxdclient.LXDHost, resourceType, imageAlias string, limitOverCommit uint64, runnerName string) (*lxdclient.LXDHost, string, error) {
s := findInstances(targets, func(i api.Instance) bool {
func allocatePooledInstance(ctx context.Context, targets []lxdclient.LXDHost, resourceType, imageAlias string, limitOverCommit uint64, runnerName string, l *slog.Logger) (*lxdclient.LXDHost, string, error) {
s := findInstances(ctx, targets, func(i api.Instance) bool {
if i.StatusCode != api.Frozen {
return false
}
Expand All @@ -148,7 +179,7 @@ func allocatePooledInstance(targets []lxdclient.LXDHost, resourceType, imageAlia
return false
}
return true
}, limitOverCommit)
}, limitOverCommit, l)

for _, i := range s {
if err := allocateInstance(*i.Host, i.InstanceName, runnerName); err != nil {
Expand Down
8 changes: 4 additions & 4 deletions server/pkg/api/server_add_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (s *ShoesLXDMultiServer) AddInstance(ctx context.Context, req *pb.AddInstan
var instanceName string

if s.poolMode {
host, instanceName, err = s.addInstancePoolMode(targetLXDHosts, req, l)
host, instanceName, err = s.addInstancePoolMode(ctx, targetLXDHosts, req, l)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -144,14 +144,14 @@ func (s *ShoesLXDMultiServer) addInstanceCreateMode(ctx context.Context, targetL
return host, instanceName, nil
}

func (s *ShoesLXDMultiServer) addInstancePoolMode(targets []lxdclient.LXDHost, req *pb.AddInstanceRequest, l *slog.Logger) (*lxdclient.LXDHost, string, error) {
host, instanceName, found := findInstanceByJob(targets, req.RunnerName)
func (s *ShoesLXDMultiServer) addInstancePoolMode(ctx context.Context, targets []lxdclient.LXDHost, req *pb.AddInstanceRequest, l *slog.Logger) (*lxdclient.LXDHost, string, error) {
host, instanceName, found := findInstanceByJob(ctx, targets, req.RunnerName, l)
if !found {
resourceTypeName := datastore.UnmarshalResourceTypePb(req.ResourceType).String()
retried := 0
for {
var err error
host, instanceName, err = allocatePooledInstance(targets, resourceTypeName, req.ImageAlias, s.overCommitPercent, req.RunnerName)
host, instanceName, err = allocatePooledInstance(ctx, targets, resourceTypeName, req.ImageAlias, s.overCommitPercent, req.RunnerName, l)
if err != nil {
if retried < 10 {
retried++
Expand Down

0 comments on commit e214990

Please sign in to comment.