From 8109917301319b54da79ce750a2f55b020b20793 Mon Sep 17 00:00:00 2001 From: Jorropo Date: Fri, 12 Jan 2024 10:07:46 +0100 Subject: [PATCH] blockservice: add `NewSessionContext` and `EmbedSessionInContext` This also include cleanup for session code. --- CHANGELOG.md | 2 + blockservice/blockservice.go | 237 ++++++++++++++++-------------- blockservice/blockservice_test.go | 65 ++++++++ 3 files changed, 195 insertions(+), 109 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb3d34c3f..734c89ebf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The following emojis are used to highlight certain changes: ### Added +- `blockservice` now has `ContextWithSession` and `EmbedSessionInContext` functions, which allows to embed a session in a context. Future calls to `BlockGetter.GetBlock`, `BlockGetter.GetBlocks` and `NewSession` will use the session in the context. + ### Changed ### Removed diff --git a/blockservice/blockservice.go b/blockservice/blockservice.go index 423697d87..aa68b8c0f 100644 --- a/blockservice/blockservice.go +++ b/blockservice/blockservice.go @@ -144,29 +144,19 @@ func (s *blockService) Allowlist() verifcid.Allowlist { // If the current exchange is a SessionExchange, a new exchange // session will be created. Otherwise, the current exchange will be used // directly. +// Sessions are lazily setup, this is cheap. func NewSession(ctx context.Context, bs BlockService) *Session { - allowlist := verifcid.Allowlist(verifcid.DefaultAllowlist) + ses := grabSessionFromContext(ctx, bs) + if ses != nil { + return ses + } + + var allowlist verifcid.Allowlist = verifcid.DefaultAllowlist if bbs, ok := bs.(BoundedBlockService); ok { allowlist = bbs.Allowlist() } - exch := bs.Exchange() - if sessEx, ok := exch.(exchange.SessionExchange); ok { - return &Session{ - allowlist: allowlist, - sessCtx: ctx, - ses: nil, - sessEx: sessEx, - bs: bs.Blockstore(), - notifier: exch, - } - } - return &Session{ - allowlist: allowlist, - ses: exch, - sessCtx: ctx, - bs: bs.Blockstore(), - notifier: exch, - } + + return &Session{bs: bs, allowlist: allowlist, sesctx: ctx} } // AddBlock adds a particular block to the service, Putting it into the datastore. @@ -248,75 +238,80 @@ func (s *blockService) AddBlocks(ctx context.Context, bs []blocks.Block) error { // GetBlock retrieves a particular block from the service, // Getting it from the datastore using the key (hash). func (s *blockService) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) { + if ses := grabSessionFromContext(ctx, s); ses != nil { + return ses.GetBlock(ctx, c) + } + ctx, span := internal.StartSpan(ctx, "blockService.GetBlock", trace.WithAttributes(attribute.Stringer("CID", c))) defer span.End() - var f func() notifiableFetcher - if s.exchange != nil { - f = s.getExchange - } - - return getBlock(ctx, c, s.blockstore, s.allowlist, f) + return getBlock(ctx, c, s, s.allowlist, s.getExchangeFetcher) } -func (s *blockService) getExchange() notifiableFetcher { +// Look at what I have to do, no interface covariance :'( +func (s *blockService) getExchangeFetcher() exchange.Fetcher { return s.exchange } -func getBlock(ctx context.Context, c cid.Cid, bs blockstore.Blockstore, allowlist verifcid.Allowlist, fget func() notifiableFetcher) (blocks.Block, error) { +func getBlock(ctx context.Context, c cid.Cid, bs BlockService, allowlist verifcid.Allowlist, fetchFactory func() exchange.Fetcher) (blocks.Block, error) { err := verifcid.ValidateCid(allowlist, c) // hash security if err != nil { return nil, err } - block, err := bs.Get(ctx, c) - if err == nil { + blockstore := bs.Blockstore() + + block, err := blockstore.Get(ctx, c) + switch { + case err == nil: return block, nil + case ipld.IsNotFound(err): + break + default: + return nil, err } - if ipld.IsNotFound(err) && fget != nil { - f := fget() // Don't load the exchange until we have to + fetch := fetchFactory() // lazily create session if needed + if fetch == nil { + logger.Debug("BlockService GetBlock: Not found") + return nil, err + } - // TODO be careful checking ErrNotFound. If the underlying - // implementation changes, this will break. - logger.Debug("BlockService: Searching") - blk, err := f.GetBlock(ctx, c) - if err != nil { - return nil, err - } - // also write in the blockstore for caching, inform the exchange that the block is available - err = bs.Put(ctx, blk) - if err != nil { - return nil, err - } - err = f.NotifyNewBlocks(ctx, blk) + logger.Debug("BlockService: Searching") + blk, err := fetch.GetBlock(ctx, c) + if err != nil { + return nil, err + } + // also write in the blockstore for caching, inform the exchange that the block is available + err = blockstore.Put(ctx, blk) + if err != nil { + return nil, err + } + if ex := bs.Exchange(); ex != nil { + err = ex.NotifyNewBlocks(ctx, blk) if err != nil { return nil, err } - logger.Debugf("BlockService.BlockFetched %s", c) - return blk, nil } - - logger.Debug("BlockService GetBlock: Not found") - return nil, err + logger.Debugf("BlockService.BlockFetched %s", c) + return blk, nil } // GetBlocks gets a list of blocks asynchronously and returns through // the returned channel. // NB: No guarantees are made about order. func (s *blockService) GetBlocks(ctx context.Context, ks []cid.Cid) <-chan blocks.Block { + if ses := grabSessionFromContext(ctx, s); ses != nil { + return ses.GetBlocks(ctx, ks) + } + ctx, span := internal.StartSpan(ctx, "blockService.GetBlocks") defer span.End() - var f func() notifiableFetcher - if s.exchange != nil { - f = s.getExchange - } - - return getBlocks(ctx, ks, s.blockstore, s.allowlist, f) + return getBlocks(ctx, ks, s, s.allowlist, s.getExchangeFetcher) } -func getBlocks(ctx context.Context, ks []cid.Cid, bs blockstore.Blockstore, allowlist verifcid.Allowlist, fget func() notifiableFetcher) <-chan blocks.Block { +func getBlocks(ctx context.Context, ks []cid.Cid, blockservice BlockService, allowlist verifcid.Allowlist, fetchFactory func() exchange.Fetcher) <-chan blocks.Block { out := make(chan blocks.Block) go func() { @@ -344,6 +339,8 @@ func getBlocks(ctx context.Context, ks []cid.Cid, bs blockstore.Blockstore, allo ks = ks2 } + bs := blockservice.Blockstore() + var misses []cid.Cid for _, c := range ks { hit, err := bs.Get(ctx, c) @@ -358,17 +355,18 @@ func getBlocks(ctx context.Context, ks []cid.Cid, bs blockstore.Blockstore, allo } } - if len(misses) == 0 || fget == nil { + fetch := fetchFactory() // don't load exchange unless we have to + if len(misses) == 0 || fetch == nil { return } - f := fget() // don't load exchange unless we have to - rblocks, err := f.GetBlocks(ctx, misses) + rblocks, err := fetch.GetBlocks(ctx, misses) if err != nil { logger.Debugf("Error with GetBlocks: %s", err) return } + ex := blockservice.Exchange() var cache [1]blocks.Block // preallocate once for all iterations for { var b blocks.Block @@ -389,14 +387,16 @@ func getBlocks(ctx context.Context, ks []cid.Cid, bs blockstore.Blockstore, allo return } - // inform the exchange that the blocks are available - cache[0] = b - err = f.NotifyNewBlocks(ctx, cache[:]...) - if err != nil { - logger.Errorf("could not tell the exchange about new blocks: %s", err) - return + if ex != nil { + // inform the exchange that the blocks are available + cache[0] = b + err = ex.NotifyNewBlocks(ctx, cache[:]...) + if err != nil { + logger.Errorf("could not tell the exchange about new blocks: %s", err) + return + } + cache[0] = nil // early gc } - cache[0] = nil // early gc select { case out <- b: @@ -428,54 +428,36 @@ func (s *blockService) Close() error { return s.exchange.Close() } -type notifier interface { - NotifyNewBlocks(context.Context, ...blocks.Block) error -} - // Session is a helper type to provide higher level access to bitswap sessions type Session struct { - allowlist verifcid.Allowlist - bs blockstore.Blockstore - ses exchange.Fetcher - sessEx exchange.SessionExchange - sessCtx context.Context - notifier notifier - lk sync.Mutex + createSession sync.Once + bs BlockService + ses exchange.Fetcher + sesctx context.Context + allowlist verifcid.Allowlist } -type notifiableFetcher interface { - exchange.Fetcher - notifier -} +// grabSession is used to lazily create sessions. +func (s *Session) grabSession() exchange.Fetcher { + s.createSession.Do(func() { + defer func() { + s.sesctx = nil // early gc + }() -type notifiableFetcherWrapper struct { - exchange.Fetcher - notifier -} - -func (s *Session) getSession() notifiableFetcher { - s.lk.Lock() - defer s.lk.Unlock() - if s.ses == nil { - s.ses = s.sessEx.NewSession(s.sessCtx) - } - - return notifiableFetcherWrapper{s.ses, s.notifier} -} + ex := s.bs.Exchange() + if ex == nil { + return + } + s.ses = ex // always fallback to non session fetches -func (s *Session) getExchange() notifiableFetcher { - return notifiableFetcherWrapper{s.ses, s.notifier} -} + sesEx, ok := ex.(exchange.SessionExchange) + if !ok { + return + } + s.ses = sesEx.NewSession(s.sesctx) + }) -func (s *Session) getFetcherFactory() func() notifiableFetcher { - if s.sessEx != nil { - return s.getSession - } - if s.ses != nil { - // Our exchange isn't session compatible, let's fallback to non sessions fetches - return s.getExchange - } - return nil + return s.ses } // GetBlock gets a block in the context of a request session @@ -483,7 +465,7 @@ func (s *Session) GetBlock(ctx context.Context, c cid.Cid) (blocks.Block, error) ctx, span := internal.StartSpan(ctx, "Session.GetBlock", trace.WithAttributes(attribute.Stringer("CID", c))) defer span.End() - return getBlock(ctx, c, s.bs, s.allowlist, s.getFetcherFactory()) + return getBlock(ctx, c, s.bs, s.allowlist, s.grabSession) } // GetBlocks gets blocks in the context of a request session @@ -491,7 +473,44 @@ func (s *Session) GetBlocks(ctx context.Context, ks []cid.Cid) <-chan blocks.Blo ctx, span := internal.StartSpan(ctx, "Session.GetBlocks") defer span.End() - return getBlocks(ctx, ks, s.bs, s.allowlist, s.getFetcherFactory()) + return getBlocks(ctx, ks, s.bs, s.allowlist, s.grabSession) } var _ BlockGetter = (*Session)(nil) + +// ContextWithSession is a helper which creates a context with an embded session, +// future calls to [BlockGetter.GetBlock], [BlockGetter.GetBlocks] and [NewSession] with the same [BlockService] +// will be redirected to this same session instead. +// Sessions are lazily setup, this is cheap. +// It wont make a new session if one exists already in the context. +func ContextWithSession(ctx context.Context, bs BlockService) context.Context { + if grabSessionFromContext(ctx, bs) != nil { + return ctx + } + return EmbedSessionInContext(ctx, NewSession(ctx, bs)) +} + +// EmbedSessionInContext is like [NewSessionContext] but it allows to embed an existing session. +func EmbedSessionInContext(ctx context.Context, ses *Session) context.Context { + // use ses.bs as a key, so if multiple blockservices use embeded sessions it gets dispatched to the matching blockservice. + return context.WithValue(ctx, ses.bs, ses) +} + +// grabSessionFromContext returns nil if the session was not found +// This is a private API on purposes, I dislike when consumers tradeoff compiletime typesafety with runtime typesafety, +// if this API is public it is too easy to forget to pass a [BlockService] or [Session] object around in your app. +// By having this private we allow consumers to follow the trace of where the blockservice is passed and used. +func grabSessionFromContext(ctx context.Context, bs BlockService) *Session { + s := ctx.Value(bs) + if s == nil { + return nil + } + + ss, ok := s.(*Session) + if !ok { + // idk what to do here, that kinda sucks, giveup + return nil + } + + return ss +} diff --git a/blockservice/blockservice_test.go b/blockservice/blockservice_test.go index e36058040..6591529d2 100644 --- a/blockservice/blockservice_test.go +++ b/blockservice/blockservice_test.go @@ -288,3 +288,68 @@ func TestAllowlist(t *testing.T) { check(blockservice.GetBlock) check(NewSession(ctx, blockservice).GetBlock) } + +type fakeIsNewSessionCreateExchange struct { + ses exchange.Fetcher + newSessionWasCalled bool +} + +var _ exchange.SessionExchange = (*fakeIsNewSessionCreateExchange)(nil) + +func (*fakeIsNewSessionCreateExchange) Close() error { + return nil +} + +func (*fakeIsNewSessionCreateExchange) GetBlock(context.Context, cid.Cid) (blocks.Block, error) { + panic("should call on the session") +} + +func (*fakeIsNewSessionCreateExchange) GetBlocks(context.Context, []cid.Cid) (<-chan blocks.Block, error) { + panic("should call on the session") +} + +func (f *fakeIsNewSessionCreateExchange) NewSession(context.Context) exchange.Fetcher { + f.newSessionWasCalled = true + return f.ses +} + +func (*fakeIsNewSessionCreateExchange) NotifyNewBlocks(context.Context, ...blocks.Block) error { + return nil +} + +func TestContextSession(t *testing.T) { + t.Parallel() + a := assert.New(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + bgen := butil.NewBlockGenerator() + block1 := bgen.Next() + block2 := bgen.Next() + + bs := blockstore.NewBlockstore(ds.NewMapDatastore()) + a.NoError(bs.Put(ctx, block1)) + a.NoError(bs.Put(ctx, block2)) + sesEx := &fakeIsNewSessionCreateExchange{ses: offline.Exchange(bs)} + + service := New(blockstore.NewBlockstore(ds.NewMapDatastore()), sesEx) + + ctx = ContextWithSession(ctx, service) + + b, err := service.GetBlock(ctx, block1.Cid()) + a.NoError(err) + a.Equal(b.RawData(), block1.RawData()) + a.True(sesEx.newSessionWasCalled, "new session from context should be created") + sesEx.newSessionWasCalled = false + + bchan := service.GetBlocks(ctx, []cid.Cid{block2.Cid()}) + a.Equal((<-bchan).RawData(), block2.RawData()) + a.False(sesEx.newSessionWasCalled, "session should be reused in context") + + a.Equal( + NewSession(ctx, service), + NewSession(ContextWithSession(ctx, service), service), + "session must be deduped in all invocations on the same context", + ) +}