From 80b274dc7dbd0c9b5dfa90bc8b3059333c154ed4 Mon Sep 17 00:00:00 2001 From: Preetam Dwivedi Date: Mon, 8 Jun 2026 11:48:44 -0700 Subject: [PATCH 1/3] refactor(mergechecker): accept entity.Request, resolve change internally MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change MergeChecker.Check to take the orchestrator's request identity (entity.Request) instead of a controller-pre-resolved entity.Change, per the extension contract. The GitHub implementation and the fake read request.Change themselves; the validate controller hands over the request it already loaded. Output is unchanged (mergechecker.Result). The factory and Config are unchanged — no dependency injection is needed since the checker resolves nothing beyond the change already on the request. --- submitqueue/extension/mergechecker/fake/fake.go | 6 +++--- submitqueue/extension/mergechecker/fake/fake_test.go | 2 +- submitqueue/extension/mergechecker/github/checker.go | 6 ++++-- submitqueue/extension/mergechecker/github/checker_test.go | 2 +- submitqueue/extension/mergechecker/mergechecker.go | 7 ++++--- .../extension/mergechecker/mock/mergechecker_mock.go | 8 ++++---- submitqueue/orchestrator/controller/validate/validate.go | 2 +- 7 files changed, 18 insertions(+), 15 deletions(-) diff --git a/submitqueue/extension/mergechecker/fake/fake.go b/submitqueue/extension/mergechecker/fake/fake.go index 1fd6867f..c66165ab 100644 --- a/submitqueue/extension/mergechecker/fake/fake.go +++ b/submitqueue/extension/mergechecker/fake/fake.go @@ -52,9 +52,9 @@ func New() mergechecker.MergeChecker { } // Check reports the change as mergeable unless a recognized marker token is -// present in one of its URIs. -func (checker) Check(_ context.Context, change entity.Change) (mergechecker.Result, error) { - switch fakemarker.Token(change.URIs) { +// present in one of the request's change URIs. +func (checker) Check(_ context.Context, request entity.Request) (mergechecker.Result, error) { + switch fakemarker.Token(request.Change.URIs) { case tokenUnmergeable: return mergechecker.Result{Mergeable: false, Reason: "fake: marked unmergeable"}, nil case tokenError: diff --git a/submitqueue/extension/mergechecker/fake/fake_test.go b/submitqueue/extension/mergechecker/fake/fake_test.go index f55a9d17..142b8dc1 100644 --- a/submitqueue/extension/mergechecker/fake/fake_test.go +++ b/submitqueue/extension/mergechecker/fake/fake_test.go @@ -66,7 +66,7 @@ func TestChecker_Check(t *testing.T) { c := New() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - res, err := c.Check(context.Background(), entity.Change{URIs: tt.uris}) + res, err := c.Check(context.Background(), entity.Request{Change: entity.Change{URIs: tt.uris}}) if tt.wantErr { require.Error(t, err) return diff --git a/submitqueue/extension/mergechecker/github/checker.go b/submitqueue/extension/mergechecker/github/checker.go index 7fb2dc92..4a5134aa 100644 --- a/submitqueue/extension/mergechecker/github/checker.go +++ b/submitqueue/extension/mergechecker/github/checker.go @@ -60,13 +60,15 @@ func NewMergeChecker(params Params) mergechecker.MergeChecker { } } -// Check assesses whether a change can merge cleanly using the GitHub GraphQL API. -func (c *mergeChecker) Check(ctx context.Context, change entity.Change) (result mergechecker.Result, retErr error) { +// Check assesses whether a request's change can merge cleanly using the GitHub GraphQL API. +func (c *mergeChecker) Check(ctx context.Context, request entity.Request) (result mergechecker.Result, retErr error) { const opName = "check" op := metrics.Begin(c.metricsScope, opName) defer func() { op.Complete(retErr) }() + change := request.Change + // Parse all change IDs // TODO: classify parse errors as user errors (non-retryable) vs system errors. changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) diff --git a/submitqueue/extension/mergechecker/github/checker_test.go b/submitqueue/extension/mergechecker/github/checker_test.go index f1089a4b..7369d65d 100644 --- a/submitqueue/extension/mergechecker/github/checker_test.go +++ b/submitqueue/extension/mergechecker/github/checker_test.go @@ -156,7 +156,7 @@ func TestMergeChecker_Check(t *testing.T) { defer server.Close() mc := newTestMergeChecker(t, server.URL) - result, err := mc.Check(context.Background(), tt.change) + result, err := mc.Check(context.Background(), entity.Request{Change: tt.change}) if tt.wantErr { require.Error(t, err) return diff --git a/submitqueue/extension/mergechecker/mergechecker.go b/submitqueue/extension/mergechecker/mergechecker.go index 50ef5d09..28c95d8e 100644 --- a/submitqueue/extension/mergechecker/mergechecker.go +++ b/submitqueue/extension/mergechecker/mergechecker.go @@ -22,12 +22,13 @@ import ( "github.com/uber/submitqueue/submitqueue/entity" ) -// MergeChecker predicts whether a set of changes can merge cleanly. +// MergeChecker predicts whether a request's changes can merge cleanly. type MergeChecker interface { // Check is a fail-fast mergeability check that optimistically assesses - // whether the changes can be merged. A positive result does not + // whether the request's changes can be merged. It is handed the request + // identity and reads request.Change itself. A positive result does not // guarantee that the changes will apply cleanly at merge time. - Check(ctx context.Context, change entity.Change) (Result, error) + Check(ctx context.Context, request entity.Request) (Result, error) } // Result holds the outcome of a mergeability check. diff --git a/submitqueue/extension/mergechecker/mock/mergechecker_mock.go b/submitqueue/extension/mergechecker/mock/mergechecker_mock.go index eab22c58..08f28f8c 100644 --- a/submitqueue/extension/mergechecker/mock/mergechecker_mock.go +++ b/submitqueue/extension/mergechecker/mock/mergechecker_mock.go @@ -43,18 +43,18 @@ func (m *MockMergeChecker) EXPECT() *MockMergeCheckerMockRecorder { } // Check mocks base method. -func (m *MockMergeChecker) Check(ctx context.Context, change entity.Change) (mergechecker.Result, error) { +func (m *MockMergeChecker) Check(ctx context.Context, request entity.Request) (mergechecker.Result, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Check", ctx, change) + ret := m.ctrl.Call(m, "Check", ctx, request) ret0, _ := ret[0].(mergechecker.Result) ret1, _ := ret[1].(error) return ret0, ret1 } // Check indicates an expected call of Check. -func (mr *MockMergeCheckerMockRecorder) Check(ctx, change any) *gomock.Call { +func (mr *MockMergeCheckerMockRecorder) Check(ctx, request any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockMergeChecker)(nil).Check), ctx, change) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockMergeChecker)(nil).Check), ctx, request) } // MockFactory is a mock of Factory interface. diff --git a/submitqueue/orchestrator/controller/validate/validate.go b/submitqueue/orchestrator/controller/validate/validate.go index 5a44309f..123d0556 100644 --- a/submitqueue/orchestrator/controller/validate/validate.go +++ b/submitqueue/orchestrator/controller/validate/validate.go @@ -140,7 +140,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r coremetrics.NamedCounter(c.metricsScope, "process", "merge_check_errors", 1) return fmt.Errorf("failed to build merge checker for queue %s: %w", request.Queue, err) } - mergeResult, err := mergeChecker.Check(ctx, request.Change) + mergeResult, err := mergeChecker.Check(ctx, request) if err != nil { coremetrics.NamedCounter(c.metricsScope, "process", "merge_check_errors", 1) return fmt.Errorf("merge check failed: %w", err) From 985169f7c70bc0cbd54c12f78d628b61ccf086b9 Mon Sep 17 00:00:00 2001 From: Preetam Dwivedi Date: Mon, 8 Jun 2026 11:52:17 -0700 Subject: [PATCH 2/3] refactor(changeprovider): accept entity.Request, resolve change internally MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change ChangeProvider.Get to take the orchestrator's request identity (entity.Request) instead of a controller-pre-resolved entity.Change, per the extension contract. The GitHub implementation and the fake read request.Change themselves; the validate controller hands over the request it already loaded. Output is unchanged: one entity.ChangeInfo per URI, each self-identifying by URI. The provider is the external resolver, so it needs no injected dependency — the factory and Config are unchanged. --- .../extension/changeprovider/change_provider.go | 5 +++-- submitqueue/extension/changeprovider/fake/fake.go | 7 ++++--- .../extension/changeprovider/fake/fake_test.go | 6 +++--- .../extension/changeprovider/github/provider.go | 6 ++++-- .../changeprovider/github/provider_test.go | 14 +++++++------- .../changeprovider/mock/change_provider_mock.go | 8 ++++---- .../orchestrator/controller/validate/validate.go | 2 +- .../controller/validate/validate_test.go | 2 +- 8 files changed, 27 insertions(+), 23 deletions(-) diff --git a/submitqueue/extension/changeprovider/change_provider.go b/submitqueue/extension/changeprovider/change_provider.go index 9b4b7433..ba55b8bf 100644 --- a/submitqueue/extension/changeprovider/change_provider.go +++ b/submitqueue/extension/changeprovider/change_provider.go @@ -29,10 +29,11 @@ import ( // entity.Author, entity.ChangedFile — live in the entity package so the same typed // facts can be persisted (entity.ChangeRecord) and scored without re-declaration. type ChangeProvider interface { - // Get retrieves change information for the provided Change. + // Get retrieves change information for the provided request. + // It is handed the request identity and reads request.Change itself. // For a Change with multiple URIs (e.g., stacked PRs), returns one ChangeInfo per URI. // Returns a slice of ChangeInfo, one for each change in the stack. - Get(ctx context.Context, change entity.Change) ([]entity.ChangeInfo, error) + Get(ctx context.Context, request entity.Request) ([]entity.ChangeInfo, error) } // Config carries the per-queue identity handed to a Factory. The system knows diff --git a/submitqueue/extension/changeprovider/fake/fake.go b/submitqueue/extension/changeprovider/fake/fake.go index 6d829c3a..1e30998e 100644 --- a/submitqueue/extension/changeprovider/fake/fake.go +++ b/submitqueue/extension/changeprovider/fake/fake.go @@ -47,9 +47,10 @@ func New() changeprovider.ChangeProvider { return provider{} } -// Get returns one ChangeInfo per URI in the change, unless a recognized marker -// token requests a failure. The "one ChangeInfo per URI" contract is preserved. -func (provider) Get(_ context.Context, change entity.Change) ([]entity.ChangeInfo, error) { +// Get returns one ChangeInfo per URI in the request's change, unless a recognized +// marker token requests a failure. The "one ChangeInfo per URI" contract is preserved. +func (provider) Get(_ context.Context, request entity.Request) ([]entity.ChangeInfo, error) { + change := request.Change if fakemarker.Token(change.URIs) == tokenError { return nil, fmt.Errorf("fake: marked provider error") } diff --git a/submitqueue/extension/changeprovider/fake/fake_test.go b/submitqueue/extension/changeprovider/fake/fake_test.go index bb7c40db..a0a4fb2f 100644 --- a/submitqueue/extension/changeprovider/fake/fake_test.go +++ b/submitqueue/extension/changeprovider/fake/fake_test.go @@ -47,7 +47,7 @@ func TestProvider_Get_OnePerURI(t *testing.T) { p := New() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - infos, err := p.Get(context.Background(), entity.Change{URIs: tt.uris}) + infos, err := p.Get(context.Background(), entity.Request{Change: entity.Change{URIs: tt.uris}}) require.NoError(t, err) require.Len(t, infos, len(tt.uris)) for i, uri := range tt.uris { @@ -59,8 +59,8 @@ func TestProvider_Get_OnePerURI(t *testing.T) { func TestProvider_Get_ErrorMarker(t *testing.T) { p := New() - _, err := p.Get(context.Background(), entity.Change{ + _, err := p.Get(context.Background(), entity.Request{Change: entity.Change{ URIs: []string{"github://owner/repo/pull/1/abc?sq-fake=provider-error"}, - }) + }}) require.Error(t, err) } diff --git a/submitqueue/extension/changeprovider/github/provider.go b/submitqueue/extension/changeprovider/github/provider.go index 15daaa6a..5b1bcf89 100644 --- a/submitqueue/extension/changeprovider/github/provider.go +++ b/submitqueue/extension/changeprovider/github/provider.go @@ -42,12 +42,14 @@ func NewProvider(params Params) changeprovider.ChangeProvider { } } -// Get retrieves change information from GitHub for the provided Change. +// Get retrieves change information from GitHub for the request's change. // Returns one ChangeInfo per URI (one per PR in stacked changes). -func (p *provider) Get(ctx context.Context, change entity.Change) (_ []entity.ChangeInfo, retErr error) { +func (p *provider) Get(ctx context.Context, request entity.Request) (_ []entity.ChangeInfo, retErr error) { op := coremetrics.Begin(p.metricsScope, "get") defer func() { op.Complete(retErr) }() + change := request.Change + // Parse all change IDs changeIDs := make([]entitygithub.ChangeID, 0, len(change.URIs)) for _, uri := range change.URIs { diff --git a/submitqueue/extension/changeprovider/github/provider_test.go b/submitqueue/extension/changeprovider/github/provider_test.go index def8e767..7907ec98 100644 --- a/submitqueue/extension/changeprovider/github/provider_test.go +++ b/submitqueue/extension/changeprovider/github/provider_test.go @@ -106,7 +106,7 @@ func TestProvider_Get(t *testing.T) { } p := newTestProvider(t, serverURL) - infos, err := p.Get(context.Background(), entity.Change{URIs: tt.uris}) + infos, err := p.Get(context.Background(), entity.Request{Change: entity.Change{URIs: tt.uris}}) if tt.wantErr { require.Error(t, err) @@ -147,9 +147,9 @@ func TestProvider_Get_Pagination(t *testing.T) { defer server.Close() p := newTestProvider(t, server.URL) - infos, err := p.Get(context.Background(), entity.Change{ + infos, err := p.Get(context.Background(), entity.Request{Change: entity.Change{ URIs: []string{"github://uber/submitqueue/pull/456/" + shaXYZ}, - }) + }}) require.NoError(t, err) assert.Equal(t, 2, callCount) @@ -170,12 +170,12 @@ func TestProvider_Get_MultiplePRs(t *testing.T) { defer server.Close() p := newTestProvider(t, server.URL) - infos, err := p.Get(context.Background(), entity.Change{ + infos, err := p.Get(context.Background(), entity.Request{Change: entity.Change{ URIs: []string{ "github://uber/submitqueue/pull/123/" + shaA, "github://uber/submitqueue/pull/456/" + shaB, }, - }) + }}) require.NoError(t, err) assert.Equal(t, 2, callCount) @@ -202,12 +202,12 @@ func TestProvider_Get_FetchError_StopsOnFirstFailure(t *testing.T) { defer server.Close() p := newTestProvider(t, server.URL) - _, err := p.Get(context.Background(), entity.Change{ + _, err := p.Get(context.Background(), entity.Request{Change: entity.Change{ URIs: []string{ "github://uber/submitqueue/pull/123/" + shaA, "github://uber/submitqueue/pull/456/" + shaB, }, - }) + }}) require.Error(t, err) assert.Equal(t, 2, callCount) diff --git a/submitqueue/extension/changeprovider/mock/change_provider_mock.go b/submitqueue/extension/changeprovider/mock/change_provider_mock.go index a1afc4ae..4440e6e8 100644 --- a/submitqueue/extension/changeprovider/mock/change_provider_mock.go +++ b/submitqueue/extension/changeprovider/mock/change_provider_mock.go @@ -43,18 +43,18 @@ func (m *MockChangeProvider) EXPECT() *MockChangeProviderMockRecorder { } // Get mocks base method. -func (m *MockChangeProvider) Get(ctx context.Context, change entity.Change) ([]entity.ChangeInfo, error) { +func (m *MockChangeProvider) Get(ctx context.Context, request entity.Request) ([]entity.ChangeInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", ctx, change) + ret := m.ctrl.Call(m, "Get", ctx, request) ret0, _ := ret[0].([]entity.ChangeInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockChangeProviderMockRecorder) Get(ctx, change any) *gomock.Call { +func (mr *MockChangeProviderMockRecorder) Get(ctx, request any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockChangeProvider)(nil).Get), ctx, change) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockChangeProvider)(nil).Get), ctx, request) } // MockFactory is a mock of Factory interface. diff --git a/submitqueue/orchestrator/controller/validate/validate.go b/submitqueue/orchestrator/controller/validate/validate.go index 123d0556..ad7d1c31 100644 --- a/submitqueue/orchestrator/controller/validate/validate.go +++ b/submitqueue/orchestrator/controller/validate/validate.go @@ -161,7 +161,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r coremetrics.NamedCounter(c.metricsScope, "process", "change_provider_errors", 1) return fmt.Errorf("failed to build change provider for queue %s: %w", request.Queue, err) } - changeInfos, err := changeProvider.Get(ctx, request.Change) + changeInfos, err := changeProvider.Get(ctx, request) if err != nil { coremetrics.NamedCounter(c.metricsScope, "process", "change_provider_errors", 1) return fmt.Errorf("failed to fetch change information for request %s: %w", request.ID, err) diff --git a/submitqueue/orchestrator/controller/validate/validate_test.go b/submitqueue/orchestrator/controller/validate/validate_test.go index 20f4804f..90e98939 100644 --- a/submitqueue/orchestrator/controller/validate/validate_test.go +++ b/submitqueue/orchestrator/controller/validate/validate_test.go @@ -46,7 +46,7 @@ func requestIDPayload(t *testing.T, id string) []byte { // mockChangeProvider is a simple mock that returns test data. type mockChangeProvider struct{} -func (m *mockChangeProvider) Get(ctx context.Context, change entity.Change) ([]entity.ChangeInfo, error) { +func (m *mockChangeProvider) Get(ctx context.Context, request entity.Request) ([]entity.ChangeInfo, error) { return []entity.ChangeInfo{ { URI: "github://org/repo/pull/123/abcdef0123456789abcdef0123456789abcdef01", From c0be66cf9c0c804cf73f80f67117cc3057075701 Mon Sep 17 00:00:00 2001 From: Preetam Dwivedi Date: Mon, 8 Jun 2026 12:00:43 -0700 Subject: [PATCH 3/3] refactor(scorer): score entity.Batch, resolve changes internally MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change Scorer.Score to take the batch identity (entity.Batch) instead of a controller-pre-resolved entity.BatchChanges, per the extension contract. The score controller drops its private collectBatchChanges walk and just hands the batch to the scorer. The heuristic scorer and the fake gain an injected changeset.Resolver and call DetailedForBatch to resolve the batch's changes themselves; the composite scorer delegates the batch to its children unchanged. The wiring constructs one resolver from the request and change stores and injects it into every scorer it builds. Output is unchanged (a single float64 score per batch). The scorer factory and Config are unchanged — the resolver is injected at construction. --- .../orchestrator/server/BUILD.bazel | 1 + .../submitqueue/orchestrator/server/main.go | 17 ++++--- .../extension/scorer/composite/scorer.go | 8 ++-- .../extension/scorer/composite/scorer_test.go | 10 ++--- submitqueue/extension/scorer/fake/BUILD.bazel | 3 ++ submitqueue/extension/scorer/fake/fake.go | 21 ++++++--- .../extension/scorer/fake/fake_test.go | 40 ++++++++++------- .../extension/scorer/heuristic/BUILD.bazel | 2 + .../extension/scorer/heuristic/scorer.go | 19 +++++--- .../extension/scorer/heuristic/scorer_test.go | 11 ++--- .../extension/scorer/mock/scorer_mock.go | 8 ++-- submitqueue/extension/scorer/scorer.go | 7 +-- .../orchestrator/controller/score/score.go | 44 +++---------------- .../controller/score/score_test.go | 38 ++++------------ 14 files changed, 105 insertions(+), 124 deletions(-) diff --git a/example/submitqueue/orchestrator/server/BUILD.bazel b/example/submitqueue/orchestrator/server/BUILD.bazel index d3a16bee..3dcb998a 100644 --- a/example/submitqueue/orchestrator/server/BUILD.bazel +++ b/example/submitqueue/orchestrator/server/BUILD.bazel @@ -19,6 +19,7 @@ go_library( "//extension/counter/mysql", "//extension/messagequeue", "//extension/messagequeue/mysql", + "//submitqueue/core/changeset", "//submitqueue/core/consumer", "//submitqueue/entity", "//submitqueue/extension/buildrunner", diff --git a/example/submitqueue/orchestrator/server/main.go b/example/submitqueue/orchestrator/server/main.go index 91e0d619..2c5411cb 100644 --- a/example/submitqueue/orchestrator/server/main.go +++ b/example/submitqueue/orchestrator/server/main.go @@ -38,6 +38,7 @@ import ( mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" extqueue "github.com/uber/submitqueue/extension/messagequeue" queueMySQL "github.com/uber/submitqueue/extension/messagequeue/mysql" + "github.com/uber/submitqueue/submitqueue/core/changeset" "github.com/uber/submitqueue/submitqueue/core/consumer" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/buildrunner" @@ -230,7 +231,7 @@ func run() error { // back to a baseline profile for queues without an explicit entry. This is // the single place queue topology is known; the extension packages stay // queue-agnostic. - queues, err := newQueueRegistry(logger, scope) + queues, err := newQueueRegistry(logger, scope, changeset.New(store.GetRequestStore(), store.GetChangeStore())) if err != nil { return fmt.Errorf("failed to build queue registry: %w", err) } @@ -796,7 +797,7 @@ func newPusher(logger *zap.Logger, scope tally.Scope) (pusher.Pusher, error) { // conflict analyzer. Queues without an explicit profile fall back to the // baseline. This is the one place queue topology lives; extension packages stay // queue-agnostic. -func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, error) { +func newQueueRegistry(logger *zap.Logger, scope tally.Scope, resolver changeset.Resolver) (queueRegistry, error) { mc, err := newMergeChecker(logger, scope) if err != nil { return queueRegistry{}, fmt.Errorf("failed to create merge checker: %w", err) @@ -833,7 +834,8 @@ func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, err changeProvider: cp, pusher: psh, buildRunner: buildfake.New(), - scorer: scorerfake.New(heuristic.New( + scorer: scorerfake.New(resolver, heuristic.New( + resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.5}}, batchLines, scope.SubScope("scorer.default"), )), @@ -845,7 +847,8 @@ func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, err // test-queue: bucketed heuristic scorer; conservative (serialized) conflicts // inherited from the baseline. testQueue := base - testQueue.scorer = scorerfake.New(heuristic.New( + testQueue.scorer = scorerfake.New(resolver, heuristic.New( + resolver, []heuristic.Bucket{ {Min: 0, Max: 1, Score: 0.95}, {Min: 2, Max: 5, Score: 0.80}, @@ -858,10 +861,10 @@ func newQueueRegistry(logger *zap.Logger, scope tally.Scope) (queueRegistry, err // e2e-test-queue: composite scorer; no conflicts (maximum parallelism). e2eQueue := base e2eQueue.analyzer = conflictfake.New(none.New(), nil) - e2eQueue.scorer = scorerfake.New(composite.New( + e2eQueue.scorer = scorerfake.New(resolver, composite.New( map[string]scorer.Scorer{ - "size": heuristic.New([]heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.8}}, batchLines, scope), - "flat": heuristic.New([]heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.6}}, batchLines, scope), + "size": heuristic.New(resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.8}}, batchLines, scope), + "flat": heuristic.New(resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: 0.6}}, batchLines, scope), }, composite.Avg, scope.SubScope("scorer.e2e-test-queue"), )) diff --git a/submitqueue/extension/scorer/composite/scorer.go b/submitqueue/extension/scorer/composite/scorer.go index 2acce898..8d90c832 100644 --- a/submitqueue/extension/scorer/composite/scorer.go +++ b/submitqueue/extension/scorer/composite/scorer.go @@ -88,15 +88,15 @@ func New(scorers map[string]scorer.Scorer, reduce ReduceFunc, scope tally.Scope) } } -// Score evaluates all child scorers and combines their results using the reduce function. -// If any child scorer returns an error, that error is returned immediately. -func (c *compositeScorer) Score(ctx context.Context, changes entity.BatchChanges) (ret float64, retErr error) { +// Score evaluates all child scorers on the batch and combines their results using the +// reduce function. If any child scorer returns an error, that error is returned immediately. +func (c *compositeScorer) Score(ctx context.Context, batch entity.Batch) (ret float64, retErr error) { op := metrics.Begin(c.scope, "score") defer func() { op.Complete(retErr) }() scores := make(map[string]float64, len(c.scorers)) for name, s := range c.scorers { - score, err := s.Score(ctx, changes) + score, err := s.Score(ctx, batch) if err != nil { return 0, err } diff --git a/submitqueue/extension/scorer/composite/scorer_test.go b/submitqueue/extension/scorer/composite/scorer_test.go index 09e7266e..9052e547 100644 --- a/submitqueue/extension/scorer/composite/scorer_test.go +++ b/submitqueue/extension/scorer/composite/scorer_test.go @@ -31,14 +31,14 @@ type fixedScorer struct { score float64 } -func (f *fixedScorer) Score(_ context.Context, _ entity.BatchChanges) (float64, error) { +func (f *fixedScorer) Score(_ context.Context, _ entity.Batch) (float64, error) { return f.score, nil } // errorScorer always returns an error. type errorScorer struct{} -func (e *errorScorer) Score(_ context.Context, _ entity.BatchChanges) (float64, error) { +func (e *errorScorer) Score(_ context.Context, _ entity.Batch) (float64, error) { return 0, fmt.Errorf("scorer failed") } @@ -99,7 +99,7 @@ func TestScorer_Score(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := New(tt.scorers, tt.reduce, tally.NoopScope) - got, err := s.Score(context.Background(), entity.BatchChanges{}) + got, err := s.Score(context.Background(), entity.Batch{}) require.NoError(t, err) assert.InDelta(t, tt.want, got, 1e-9) }) @@ -111,7 +111,7 @@ func TestScorer_Score_ChildError(t *testing.T) { "error": &errorScorer{}, "files": &fixedScorer{0.9}, }, Min, tally.NoopScope) - _, err := s.Score(context.Background(), entity.BatchChanges{}) + _, err := s.Score(context.Background(), entity.Batch{}) require.Error(t, err) } @@ -140,7 +140,7 @@ func TestReduceFunc_ReceivesNames(t *testing.T) { "files": &fixedScorer{0.9}, "deps": &fixedScorer{0.95}, }, custom, tally.NoopScope) - got, err := s.Score(context.Background(), entity.BatchChanges{}) + got, err := s.Score(context.Background(), entity.Batch{}) require.NoError(t, err) assert.Equal(t, 0.9, got) assert.ElementsMatch(t, []string{"files", "deps"}, receivedNames) diff --git a/submitqueue/extension/scorer/fake/BUILD.bazel b/submitqueue/extension/scorer/fake/BUILD.bazel index f1335904..f5c857f2 100644 --- a/submitqueue/extension/scorer/fake/BUILD.bazel +++ b/submitqueue/extension/scorer/fake/BUILD.bazel @@ -6,6 +6,7 @@ go_library( importpath = "github.com/uber/submitqueue/submitqueue/extension/scorer/fake", visibility = ["//visibility:public"], deps = [ + "//submitqueue/core/changeset", "//submitqueue/core/fakemarker", "//submitqueue/entity", "//submitqueue/extension/scorer", @@ -17,6 +18,8 @@ go_test( srcs = ["fake_test.go"], embed = [":fake"], deps = [ + "//submitqueue/core/changeset", + "//submitqueue/core/changeset/fake", "//submitqueue/entity", "//submitqueue/extension/scorer", "//submitqueue/extension/scorer/heuristic", diff --git a/submitqueue/extension/scorer/fake/fake.go b/submitqueue/extension/scorer/fake/fake.go index fcd6219e..b5a76c78 100644 --- a/submitqueue/extension/scorer/fake/fake.go +++ b/submitqueue/extension/scorer/fake/fake.go @@ -27,6 +27,7 @@ import ( "context" "fmt" + "github.com/uber/submitqueue/submitqueue/core/changeset" "github.com/uber/submitqueue/submitqueue/core/fakemarker" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" @@ -36,25 +37,31 @@ import ( const tokenError = "score-error" // scorerFake decorates a delegate Scorer, injecting an error when a change URI -// carries the failure marker. +// carries the failure marker. It resolves the batch itself to inspect URIs. type scorerFake struct { + resolver changeset.Resolver delegate scorer.Scorer } // New returns a scorer.Scorer that delegates to the given scorer but returns an -// error when a change URI carries the "sq-fake=score-error" marker. The delegate -// is the existing scorer implementation to wrap (e.g. heuristic or composite). -func New(delegate scorer.Scorer) scorer.Scorer { - return scorerFake{delegate: delegate} +// error when a change URI carries the "sq-fake=score-error" marker. The resolver +// resolves the batch's changes so the marker can be inspected; the delegate is the +// existing scorer implementation to wrap (e.g. heuristic or composite). +func New(resolver changeset.Resolver, delegate scorer.Scorer) scorer.Scorer { + return scorerFake{resolver: resolver, delegate: delegate} } // Score returns an error when a change URI carries the failure marker; otherwise // it delegates to the wrapped scorer. -func (s scorerFake) Score(ctx context.Context, changes entity.BatchChanges) (float64, error) { +func (s scorerFake) Score(ctx context.Context, batch entity.Batch) (float64, error) { + changes, err := s.resolver.DetailedForBatch(ctx, batch) + if err != nil { + return 0, err + } if markerToken(changes) == tokenError { return 0, fmt.Errorf("fake: marked score error") } - return s.delegate.Score(ctx, changes) + return s.delegate.Score(ctx, batch) } // markerToken returns the marker token embedded in the first change URI that diff --git a/submitqueue/extension/scorer/fake/fake_test.go b/submitqueue/extension/scorer/fake/fake_test.go index 5a0529ef..7c09ed51 100644 --- a/submitqueue/extension/scorer/fake/fake_test.go +++ b/submitqueue/extension/scorer/fake/fake_test.go @@ -21,42 +21,50 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" + "github.com/uber/submitqueue/submitqueue/core/changeset" + changesetfake "github.com/uber/submitqueue/submitqueue/core/changeset/fake" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" "github.com/uber/submitqueue/submitqueue/extension/scorer/heuristic" ) +const batchID = "q/batch/1" + func TestNew_ImplementsInterface(t *testing.T) { - var _ scorer.Scorer = New(nil) + var _ scorer.Scorer = New(nil, nil) +} + +// resolverFor returns a changeset resolver seeded so that batchID's detailed +// changes carry the given URIs. +func resolverFor(uris ...string) changeset.Resolver { + changes := make([]entity.ChangeInfo, 0, len(uris)) + for _, u := range uris { + changes = append(changes, entity.ChangeInfo{URI: u}) + } + return changesetfake.New().SetDetailed(batchID, entity.BatchChanges{BatchID: batchID, Queue: "q", Changes: changes}) } -// delegate returns a heuristic scorer that scores every batch at want. -func delegate(t *testing.T, want float64) scorer.Scorer { - t.Helper() +// delegate returns a heuristic scorer (backed by resolver) that scores every batch at want. +func delegate(resolver changeset.Resolver, want float64) scorer.Scorer { return heuristic.New( + resolver, []heuristic.Bucket{{Min: 0, Max: 1<<31 - 1, Score: want}}, func(_ context.Context, c entity.BatchChanges) (int, error) { return len(c.Changes), nil }, tally.NoopScope, ) } -func batch(uris ...string) entity.BatchChanges { - changes := make([]entity.ChangeInfo, 0, len(uris)) - for _, u := range uris { - changes = append(changes, entity.ChangeInfo{URI: u}) - } - return entity.BatchChanges{BatchID: "q/batch/1", Queue: "q", Changes: changes} -} - func TestScore_DelegatesWhenUnmarked(t *testing.T) { - s := New(delegate(t, 0.7)) - got, err := s.Score(context.Background(), batch("github://o/r/pull/1/a")) + r := resolverFor("github://o/r/pull/1/a") + s := New(r, delegate(r, 0.7)) + got, err := s.Score(context.Background(), entity.Batch{ID: batchID}) require.NoError(t, err) assert.Equal(t, 0.7, got) } func TestScore_ErrorMarker(t *testing.T) { - s := New(delegate(t, 0.7)) - _, err := s.Score(context.Background(), batch("github://o/r/pull/1/a?sq-fake=score-error")) + r := resolverFor("github://o/r/pull/1/a?sq-fake=score-error") + s := New(r, delegate(r, 0.7)) + _, err := s.Score(context.Background(), entity.Batch{ID: batchID}) require.Error(t, err) } diff --git a/submitqueue/extension/scorer/heuristic/BUILD.bazel b/submitqueue/extension/scorer/heuristic/BUILD.bazel index 7b66d777..3f45785e 100644 --- a/submitqueue/extension/scorer/heuristic/BUILD.bazel +++ b/submitqueue/extension/scorer/heuristic/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//core/metrics", + "//submitqueue/core/changeset", "//submitqueue/entity", "//submitqueue/extension/scorer", "@com_github_uber_go_tally//:tally", @@ -18,6 +19,7 @@ go_test( srcs = ["scorer_test.go"], embed = [":heuristic"], deps = [ + "//submitqueue/core/changeset/fake", "//submitqueue/entity", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/submitqueue/extension/scorer/heuristic/scorer.go b/submitqueue/extension/scorer/heuristic/scorer.go index b1afea1c..ac20d559 100644 --- a/submitqueue/extension/scorer/heuristic/scorer.go +++ b/submitqueue/extension/scorer/heuristic/scorer.go @@ -20,6 +20,7 @@ import ( "github.com/uber-go/tally" "github.com/uber/submitqueue/core/metrics" + "github.com/uber/submitqueue/submitqueue/core/changeset" "github.com/uber/submitqueue/submitqueue/entity" "github.com/uber/submitqueue/submitqueue/extension/scorer" ) @@ -40,6 +41,8 @@ type Bucket struct { // heuristicScorer computes a success probability by bucketing a metric extracted from a batch of changes. // It follows the Java HeuristicsBasedSuccessPredictor pattern. type heuristicScorer struct { + // resolver resolves the batch identity into its detailed changes. + resolver changeset.Resolver // buckets is the list of ranges to match against. buckets []Bucket // valueFunc extracts the numeric value from a batch of changes. @@ -48,24 +51,30 @@ type heuristicScorer struct { scope tally.Scope } -// New creates a new heuristic Scorer with the given buckets and value function. +// New creates a new heuristic Scorer with the given resolver, buckets and value function. // Panics if valueFunc is nil. -func New(buckets []Bucket, valueFunc ValueFunc, scope tally.Scope) scorer.Scorer { +func New(resolver changeset.Resolver, buckets []Bucket, valueFunc ValueFunc, scope tally.Scope) scorer.Scorer { if valueFunc == nil { panic("heuristic.New: valueFunc must not be nil") } return &heuristicScorer{ + resolver: resolver, buckets: buckets, valueFunc: valueFunc, scope: scope, } } -// Score extracts the value from the batch of changes, then returns the probability score for the -// first bucket whose range [Min, Max] contains the value. Returns an error if no bucket matches. -func (s *heuristicScorer) Score(ctx context.Context, changes entity.BatchChanges) (ret float64, retErr error) { +// Score resolves the batch's changes, extracts the metric, then returns the probability +// score for the first bucket whose range [Min, Max] contains the value. Returns an error +// if no bucket matches. +func (s *heuristicScorer) Score(ctx context.Context, batch entity.Batch) (ret float64, retErr error) { op := metrics.Begin(s.scope, "score") defer func() { op.Complete(retErr) }() + changes, err := s.resolver.DetailedForBatch(ctx, batch) + if err != nil { + return 0, err + } value, err := s.valueFunc(ctx, changes) if err != nil { return 0, err diff --git a/submitqueue/extension/scorer/heuristic/scorer_test.go b/submitqueue/extension/scorer/heuristic/scorer_test.go index d65a9bae..5255de64 100644 --- a/submitqueue/extension/scorer/heuristic/scorer_test.go +++ b/submitqueue/extension/scorer/heuristic/scorer_test.go @@ -21,6 +21,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" + changesetfake "github.com/uber/submitqueue/submitqueue/core/changeset/fake" "github.com/uber/submitqueue/submitqueue/entity" ) @@ -106,8 +107,8 @@ func TestScorer_Score(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := New(tt.buckets, tt.valueFunc, tally.NoopScope) - got, err := s.Score(context.Background(), entity.BatchChanges{}) + s := New(changesetfake.New(), tt.buckets, tt.valueFunc, tally.NoopScope) + got, err := s.Score(context.Background(), entity.Batch{}) if tt.wantErr { require.Error(t, err) return @@ -122,13 +123,13 @@ func TestScorer_Score_ValueFuncError(t *testing.T) { failing := func(_ context.Context, _ entity.BatchChanges) (int, error) { return 0, assert.AnError } - s := New([]Bucket{{Min: 0, Max: 10, Score: 0.9}}, failing, tally.NoopScope) - _, err := s.Score(context.Background(), entity.BatchChanges{}) + s := New(changesetfake.New(), []Bucket{{Min: 0, Max: 10, Score: 0.9}}, failing, tally.NoopScope) + _, err := s.Score(context.Background(), entity.Batch{}) require.Error(t, err) } func TestNew_NilValueFunc(t *testing.T) { assert.Panics(t, func() { - New([]Bucket{{Min: 0, Max: 10, Score: 0.85}}, nil, tally.NoopScope) + New(changesetfake.New(), []Bucket{{Min: 0, Max: 10, Score: 0.85}}, nil, tally.NoopScope) }) } diff --git a/submitqueue/extension/scorer/mock/scorer_mock.go b/submitqueue/extension/scorer/mock/scorer_mock.go index 72edc280..9b64b754 100644 --- a/submitqueue/extension/scorer/mock/scorer_mock.go +++ b/submitqueue/extension/scorer/mock/scorer_mock.go @@ -43,18 +43,18 @@ func (m *MockScorer) EXPECT() *MockScorerMockRecorder { } // Score mocks base method. -func (m *MockScorer) Score(ctx context.Context, changes entity.BatchChanges) (float64, error) { +func (m *MockScorer) Score(ctx context.Context, batch entity.Batch) (float64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Score", ctx, changes) + ret := m.ctrl.Call(m, "Score", ctx, batch) ret0, _ := ret[0].(float64) ret1, _ := ret[1].(error) return ret0, ret1 } // Score indicates an expected call of Score. -func (mr *MockScorerMockRecorder) Score(ctx, changes any) *gomock.Call { +func (mr *MockScorerMockRecorder) Score(ctx, batch any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Score", reflect.TypeOf((*MockScorer)(nil).Score), ctx, changes) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Score", reflect.TypeOf((*MockScorer)(nil).Score), ctx, batch) } // MockFactory is a mock of Factory interface. diff --git a/submitqueue/extension/scorer/scorer.go b/submitqueue/extension/scorer/scorer.go index 6837448e..b3af1b09 100644 --- a/submitqueue/extension/scorer/scorer.go +++ b/submitqueue/extension/scorer/scorer.go @@ -22,11 +22,12 @@ import ( "github.com/uber/submitqueue/submitqueue/entity" ) -// Scorer computes a success probability score for a batch of changes based on their characteristics. +// Scorer computes a success probability score for a batch based on its changes. type Scorer interface { // Score returns a probability between 0.0 and 1.0 indicating the likelihood - // of a successful land for the given batch of changes. - Score(ctx context.Context, changes entity.BatchChanges) (float64, error) + // of a successful land for the given batch. It is handed the batch identity + // and resolves the batch's changes itself through an injected changeset.Resolver. + Score(ctx context.Context, batch entity.Batch) (float64, error) } // Config carries the per-queue identity handed to a Factory. The system knows diff --git a/submitqueue/orchestrator/controller/score/score.go b/submitqueue/orchestrator/controller/score/score.go index 836c4449..6eb6e1f1 100644 --- a/submitqueue/orchestrator/controller/score/score.go +++ b/submitqueue/orchestrator/controller/score/score.go @@ -130,7 +130,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r return nil } - // Score each request's change and take the minimum (worst-case) as the batch score + // Score the batch. The scorer resolves the batch's changes itself. batchScore, err := c.scoreBatch(ctx, batch) if err != nil { metrics.NamedCounter(c.metricsScope, opName, "scorer_errors", 1) @@ -173,56 +173,22 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) (r return nil // Success - message will be acked } -// scoreBatch normalizes the batch's changes and scores them as a whole. It resolves -// each request in the batch, reads that request's change records (one per URI), and -// flattens their provider-supplied details into a single entity.BatchChanges, which -// the scorer turns into one probability for the batch. +// scoreBatch builds the queue's scorer and scores the batch. The scorer is handed +// the batch identity and resolves the batch's changes itself (via the shared +// changeset resolver injected at its factory), turning them into one probability. func (c *Controller) scoreBatch(ctx context.Context, batch entity.Batch) (float64, error) { sc, err := c.scorers.For(scorer.Config{QueueName: batch.Queue}) if err != nil { return 0, fmt.Errorf("failed to build scorer for batch %s: %w", batch.ID, err) } - changes, err := c.collectBatchChanges(ctx, batch) - if err != nil { - return 0, err - } - - score, err := sc.Score(ctx, changes) + score, err := sc.Score(ctx, batch) if err != nil { return 0, fmt.Errorf("failed to score batch %s: %w", batch.ID, err) } return score, nil } -// collectBatchChanges assembles the normalized entity.BatchChanges for a batch by -// resolving each request and reading its change records per URI. For each URI it -// selects the record owned by the request (GetByURI returns rows for all requests -// that ever claimed the URI) and appends its URI + details. -func (c *Controller) collectBatchChanges(ctx context.Context, batch entity.Batch) (entity.BatchChanges, error) { - changes := entity.BatchChanges{BatchID: batch.ID, Queue: batch.Queue} - for _, requestID := range batch.Contains { - request, err := c.store.GetRequestStore().Get(ctx, requestID) - if err != nil { - return entity.BatchChanges{}, fmt.Errorf("failed to get request %s: %w", requestID, err) - } - for _, uri := range request.Change.URIs { - records, err := c.store.GetChangeStore().GetByURI(ctx, batch.Queue, uri) - if err != nil { - return entity.BatchChanges{}, fmt.Errorf("failed to read change record for request %s uri=%s: %w", requestID, uri, err) - } - for _, rec := range records { - if rec.RequestID != requestID { - continue - } - changes.Changes = append(changes.Changes, entity.ChangeInfo{URI: rec.URI, Details: rec.Details}) - break - } - } - } - return changes, nil -} - // publish publishes a batch ID to the specified topic key. func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, batchID string, partitionKey string) error { bid := entity.BatchID{ID: batchID} diff --git a/submitqueue/orchestrator/controller/score/score_test.go b/submitqueue/orchestrator/controller/score/score_test.go index a716815f..9a0ca183 100644 --- a/submitqueue/orchestrator/controller/score/score_test.go +++ b/submitqueue/orchestrator/controller/score/score_test.go @@ -166,8 +166,10 @@ func TestController_Process_Success(t *testing.T) { require.NoError(t, err) } -// TestController_Process_BatchLevelScore verifies the controller assembles all of the -// batch's changes into one BatchChanges and persists the single score the scorer returns. +// TestController_Process_BatchLevelScore verifies the controller hands the batch +// identity to the scorer and persists the single score it returns. Resolving the +// batch's changes is the scorer's concern (via the changeset resolver), not the +// controller's. func TestController_Process_BatchLevelScore(t *testing.T) { ctrl := gomock.NewController(t) @@ -179,41 +181,19 @@ func TestController_Process_BatchLevelScore(t *testing.T) { Version: 1, } - request1 := entity.Request{ - ID: "test-queue/1", - Queue: "test-queue", - Change: entity.Change{URIs: []string{"github://uber/repo/pull/1/aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}}, - State: entity.RequestStateStarted, - Version: 1, - } - request2 := entity.Request{ - ID: "test-queue/2", - Queue: "test-queue", - Change: entity.Change{URIs: []string{"github://uber/repo/pull/2/bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"}}, - State: entity.RequestStateStarted, - Version: 1, - } - mockBatchStore := storagemock.NewMockBatchStore(ctrl) mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) // The single batch-level score is persisted. mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, batch.Version+1, 0.7, entity.BatchStateScored).Return(nil) - mockRequestStore := storagemock.NewMockRequestStore(ctrl) - mockRequestStore.EXPECT().Get(gomock.Any(), "test-queue/1").Return(request1, nil) - mockRequestStore.EXPECT().Get(gomock.Any(), "test-queue/2").Return(request2, nil) - store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() - store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() - store.EXPECT().GetChangeStore().Return(mockChangeStore(ctrl, request1, request2)).AnyTimes() - // Capture the BatchChanges to assert both requests' changes were gathered. + // The controller passes the batch identity to the scorer and persists its score. mockScorer := scorermock.NewMockScorer(ctrl) mockScorer.EXPECT().Score(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, changes entity.BatchChanges) (float64, error) { - assert.Equal(t, batch.ID, changes.BatchID) - assert.Len(t, changes.Changes, 2) + func(_ context.Context, b entity.Batch) (float64, error) { + assert.Equal(t, batch.ID, b.ID) return 0.7, nil }, ) @@ -260,7 +240,7 @@ func TestController_Process_ScorerFailure(t *testing.T) { mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) mockRequestStore := storagemock.NewMockRequestStore(ctrl) - mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil).AnyTimes() store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() @@ -292,7 +272,7 @@ func TestController_Process_UpdateScoreFailure(t *testing.T) { mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, batch.Version+1, gomock.Any(), entity.BatchStateScored).Return(fmt.Errorf("version mismatch")) mockRequestStore := storagemock.NewMockRequestStore(ctrl) - mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil).AnyTimes() store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes()