Index: impl/memory/taskqueue.go |
diff --git a/impl/memory/taskqueue.go b/impl/memory/taskqueue.go |
index 2ee29fe0dfd392f5a8c0b1e31db351418be9b29c..b53ff6f1acbcc5a2435f3632ddb7432d148deaac 100644 |
--- a/impl/memory/taskqueue.go |
+++ b/impl/memory/taskqueue.go |
@@ -6,11 +6,11 @@ package memory |
import ( |
"regexp" |
- "sync/atomic" |
"golang.org/x/net/context" |
tq "github.com/luci/gae/service/taskqueue" |
+ |
"github.com/luci/luci-go/common/data/rand/mathrand" |
"github.com/luci/luci-go/common/errors" |
) |
@@ -18,20 +18,14 @@ import ( |
/////////////////////////////// public functions /////////////////////////////// |
func useTQ(c context.Context) context.Context { |
- return tq.SetRawFactory(c, func(ic context.Context, wantTxn bool) tq.RawInterface { |
- ns, _ := curGID(ic).getNamespace() |
- var tqd memContextObj |
- |
- if !wantTxn { |
- tqd = curNoTxn(ic).Get(memContextTQIdx) |
- } else { |
- tqd = cur(ic).Get(memContextTQIdx) |
- } |
+ return tq.SetRawFactory(c, func(ic context.Context) tq.RawInterface { |
+ memCtx, isTxn := cur(ic) |
+ tqd := memCtx.Get(memContextTQIdx) |
- if x, ok := tqd.(*taskQueueData); ok { |
- return &taskqueueImpl{x, ic, ns} |
+ if isTxn { |
+ return &taskqueueTxnImpl{tqd.(*txnTaskQueueData), ic} |
} |
- return &taskqueueTxnImpl{tqd.(*txnTaskQueueData), ic, ns} |
+ return &taskqueueImpl{tqd.(*taskQueueData), ic} |
}) |
} |
@@ -41,7 +35,6 @@ type taskqueueImpl struct { |
*taskQueueData |
ctx context.Context |
- ns string |
} |
var ( |
@@ -50,7 +43,7 @@ var ( |
) |
func (t *taskqueueImpl) addLocked(task *tq.Task, queueName string) (*tq.Task, error) { |
- toSched, err := t.prepTask(t.ctx, t.ns, task, queueName) |
+ toSched, err := t.prepTask(t.ctx, task, queueName) |
if err != nil { |
return nil, err |
} |
@@ -145,9 +138,7 @@ func (t *taskqueueImpl) Stats(queueNames []string, cb tq.RawStatsCB) error { |
return nil |
} |
-func (t *taskqueueImpl) Testable() tq.Testable { |
- return t |
-} |
+func (t *taskqueueImpl) GetTestable() tq.Testable { return t } |
/////////////////////////////// taskqueueTxnImpl /////////////////////////////// |
@@ -155,7 +146,6 @@ type taskqueueTxnImpl struct { |
*txnTaskQueueData |
ctx context.Context |
- ns string |
} |
var _ interface { |
@@ -164,7 +154,7 @@ var _ interface { |
} = (*taskqueueTxnImpl)(nil) |
func (t *taskqueueTxnImpl) addLocked(task *tq.Task, queueName string) (*tq.Task, error) { |
- toSched, err := t.parent.prepTask(t.ctx, t.ns, task, queueName) |
+ toSched, err := t.parent.prepTask(t.ctx, task, queueName) |
if err != nil { |
return nil, err |
} |
@@ -196,8 +186,8 @@ func (t *taskqueueTxnImpl) addLocked(task *tq.Task, queueName string) (*tq.Task, |
} |
func (t *taskqueueTxnImpl) AddMulti(tasks []*tq.Task, queueName string, cb tq.RawTaskCB) error { |
- if atomic.LoadInt32(&t.closed) == 1 { |
- return errors.New("taskqueue: transaction context has expired") |
+ if err := assertTxnValid(t.ctx); err != nil { |
+ return err |
} |
t.Lock() |
@@ -226,9 +216,7 @@ func (t *taskqueueTxnImpl) Stats([]string, tq.RawStatsCB) error { |
return errors.New("taskqueue: cannot Stats from a transaction") |
} |
-func (t *taskqueueTxnImpl) Testable() tq.Testable { |
- return t |
-} |
+func (t *taskqueueTxnImpl) GetTestable() tq.Testable { return t } |
////////////////////////////// private functions /////////////////////////////// |