diff --git a/pkg/queue/queue.go b/pkg/queue/queue.go new file mode 100644 index 0000000000000000000000000000000000000000..356f17f4125765f6822551358b79cc1037cf17b8 --- /dev/null +++ b/pkg/queue/queue.go @@ -0,0 +1,169 @@ +package queue + +import ( + "runtime" + "sync" + "time" + + "git.perx.ru/perxis/perxis-go/pkg/errors" + "git.perx.ru/perxis/perxis-go/pkg/id" +) + +const ( + defaultSize = 100 + defaultStoreResultsTTL = 1 * time.Hour +) + +type Waiter interface { + Wait(string) (*JobResult, error) +} + +type Job func() error + +type JobResult struct { + Err error +} + +type JobGroup struct { + wg sync.WaitGroup +} + +func (g *JobGroup) Add(j Job) Job { + g.wg.Add(1) + return func() error { + defer g.wg.Done() + return j() + } +} + +func (g *JobGroup) Wait() { + g.wg.Wait() +} + +type Queue struct { + jobsCh chan Job + results sync.Map + StoreResultsTTL time.Duration + serveWG sync.WaitGroup + done chan struct{} + NumWorkers int + Size int +} + +func (j *Queue) AddJob(job Job) (jobID string, err error) { + if j == nil { + return + } + + jobID = id.GenerateNewID() + resCh := make(chan *JobResult, 1) + + trackedJob := func() error { + err := job() + resCh <- &JobResult{Err: err} + close(resCh) + + go func() { + select { + case <-j.done: + return + case <-time.After(j.StoreResultsTTL): + j.results.Delete(jobID) + } + }() + + return err + } + + // нужно добавить до того, как задача попадет в очередь + j.results.Store(jobID, resCh) + + select { + case j.jobsCh <- trackedJob: + default: + j.results.Delete(jobID) + return "", errors.New("queue size exceeded") + } + + return jobID, nil +} + +func (j *Queue) Wait(jobID string) (*JobResult, error) { + resCh, ok := j.results.Load(jobID) + if !ok { + return nil, errors.Errorf("job '%s' not found", jobID) + } + res := <-resCh.(chan *JobResult) + return res, nil +} + +func (j *Queue) WaitCh(jobID string) (<-chan *JobResult, error) { + resCh, ok := j.results.Load(jobID) + if !ok { + return nil, errors.Errorf("job '%s' not found", jobID) + } + return resCh.(chan *JobResult), nil +} + +func (j *Queue) IsStarted() bool { + return j != nil && j.jobsCh != nil && j.done != nil +} + +func (j *Queue) Start() { + if j == nil { + panic("job runner not created") + } + + if j.jobsCh != nil || j.done != nil { + return + } + + if j.Size == 0 { + j.Size = defaultSize + } + + j.jobsCh = make(chan Job, j.Size) + j.done = make(chan struct{}) + j.serveWG = sync.WaitGroup{} + + if j.StoreResultsTTL == 0 { + j.StoreResultsTTL = defaultStoreResultsTTL + } + + if j.NumWorkers == 0 { + j.NumWorkers = runtime.NumCPU() + } + + j.serveWG.Add(j.NumWorkers) + for i := 0; i < j.NumWorkers; i++ { + go j.worker() + } +} + +func (j *Queue) Stop() { + if j.done == nil && j.jobsCh == nil { + return + } + + close(j.done) + j.serveWG.Wait() + close(j.jobsCh) + j.done = nil + j.jobsCh = nil +} + +func (j *Queue) worker() { + defer j.serveWG.Done() + + for { + select { + case job, ok := <-j.jobsCh: + if !ok { + return // channel closed + } + _ = job() + case <-j.done: + return + } + } +} diff --git a/pkg/queue/queue_test.go b/pkg/queue/queue_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b1c28ff8c1d9e7a3c9f0bca32db47f4418401c09 --- /dev/null +++ b/pkg/queue/queue_test.go @@ -0,0 +1,60 @@ +package queue + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func f(v *int32, timeout int) (err error) { + atomic.AddInt32(v, 1) + time.Sleep(time.Millisecond * time.Duration(timeout)) + return +} + +func TestDefaultQueueRun(t *testing.T) { + var v int32 + q := &Queue{} + q.Start() + + jg := &JobGroup{} + for i := 0; i < 100; i++ { + j := func() (err error) { + return f(&v, 1) + } + _, err := q.AddJob(jg.Add(j)) + require.NoError(t, err) + } + jg.Wait() + q.Stop() + assert.Equal(t, int32(100), v) +} + +func TestQueueSizeExceededError(t *testing.T) { + var v int32 + q := &Queue{NumWorkers: 1, Size: 2} + q.Start() + + // попадает в очередь без ошибки + j := func() (err error) { + return f(&v, 10) + } + + var err error + var i int + for i = 0; i < 5; i++ { + _, err = q.AddJob(j) + if err != nil { + break + } + } + + require.Error(t, err) + assert.Equal(t, "queue size exceeded", err.Error()) + assert.Greater(t, i, 1) + + q.Stop() +}