Index: impl/prod/raw_datastore.go |
diff --git a/impl/prod/raw_datastore.go b/impl/prod/raw_datastore.go |
index 0eee913e1e9e239c64f54dbce67038da5a0d1678..de70c43f491dcbd430afc49be1806c22f1be5a83 100644 |
--- a/impl/prod/raw_datastore.go |
+++ b/impl/prod/raw_datastore.go |
@@ -14,17 +14,13 @@ import ( |
// useRDS adds a gae.RawDatastore implementation to context, accessible |
// by gae.GetDS(c) |
func useRDS(c context.Context) context.Context { |
- return ds.SetRawFactory(c, func(ci context.Context, wantTxn bool) ds.RawInterface { |
- maybeTxnCtx := AEContext(ci) |
- |
- if wantTxn { |
- return rdsImpl{ci, maybeTxnCtx} |
- } |
- aeCtx := AEContextNoTxn(ci) |
- if maybeTxnCtx != aeCtx { |
- ci = context.WithValue(ci, prodContextKey, aeCtx) |
+ return ds.SetRawFactory(c, func(ci context.Context) ds.RawInterface { |
+ rds := rdsImpl{ |
+ userCtx: ci, |
+ ps: getProdState(ci), |
} |
- return rdsImpl{ci, aeCtx} |
+ rds.aeCtx = rds.ps.context(ci) |
+ return &rds |
}) |
} |
@@ -35,8 +31,12 @@ type rdsImpl struct { |
// it. |
userCtx context.Context |
- // aeCtx is the context with the appengine connection information in it. |
+ // aeCtx is the AppEngine Context that will be used in method calls. This is |
+ // derived from ps. |
aeCtx context.Context |
+ |
+ // ps is the current production state. |
+ ps prodState |
} |
func idxCallbacker(err error, amt int, cb func(idx int, err error) error) error { |
@@ -61,7 +61,7 @@ func idxCallbacker(err error, amt int, cb func(idx int, err error) error) error |
return err |
} |
-func (d rdsImpl) AllocateIDs(keys []*ds.Key, cb ds.NewKeyCB) error { |
+func (d *rdsImpl) AllocateIDs(keys []*ds.Key, cb ds.NewKeyCB) error { |
// Map keys by entity type. |
entityMap := make(map[string][]int) |
for i, key := range keys { |
@@ -106,7 +106,7 @@ func (d rdsImpl) AllocateIDs(keys []*ds.Key, cb ds.NewKeyCB) error { |
return nil |
} |
-func (d rdsImpl) DeleteMulti(ks []*ds.Key, cb ds.DeleteMultiCB) error { |
+func (d *rdsImpl) DeleteMulti(ks []*ds.Key, cb ds.DeleteMultiCB) error { |
keys, err := dsMF2R(d.aeCtx, ks) |
if err == nil { |
err = datastore.DeleteMulti(d.aeCtx, keys) |
@@ -116,7 +116,7 @@ func (d rdsImpl) DeleteMulti(ks []*ds.Key, cb ds.DeleteMultiCB) error { |
}) |
} |
-func (d rdsImpl) GetMulti(keys []*ds.Key, _meta ds.MultiMetaGetter, cb ds.GetMultiCB) error { |
+func (d *rdsImpl) GetMulti(keys []*ds.Key, _meta ds.MultiMetaGetter, cb ds.GetMultiCB) error { |
vals := make([]datastore.PropertyLoadSaver, len(keys)) |
rkeys, err := dsMF2R(d.aeCtx, keys) |
if err == nil { |
@@ -133,7 +133,7 @@ func (d rdsImpl) GetMulti(keys []*ds.Key, _meta ds.MultiMetaGetter, cb ds.GetMul |
}) |
} |
-func (d rdsImpl) PutMulti(keys []*ds.Key, vals []ds.PropertyMap, cb ds.NewKeyCB) error { |
+func (d *rdsImpl) PutMulti(keys []*ds.Key, vals []ds.PropertyMap, cb ds.NewKeyCB) error { |
rkeys, err := dsMF2R(d.aeCtx, keys) |
if err == nil { |
rvals := make([]datastore.PropertyLoadSaver, len(vals)) |
@@ -151,7 +151,7 @@ func (d rdsImpl) PutMulti(keys []*ds.Key, vals []ds.PropertyMap, cb ds.NewKeyCB) |
}) |
} |
-func (d rdsImpl) fixQuery(fq *ds.FinalizedQuery) (*datastore.Query, error) { |
+func (d *rdsImpl) fixQuery(fq *ds.FinalizedQuery) (*datastore.Query, error) { |
ret := datastore.NewQuery(fq.Kind()) |
start, end := fq.Bounds() |
@@ -226,11 +226,11 @@ func (d rdsImpl) fixQuery(fq *ds.FinalizedQuery) (*datastore.Query, error) { |
return ret, nil |
} |
-func (d rdsImpl) DecodeCursor(s string) (ds.Cursor, error) { |
+func (d *rdsImpl) DecodeCursor(s string) (ds.Cursor, error) { |
return datastore.DecodeCursor(s) |
} |
-func (d rdsImpl) Run(fq *ds.FinalizedQuery, cb ds.RawRunCB) error { |
+func (d *rdsImpl) Run(fq *ds.FinalizedQuery, cb ds.RawRunCB) error { |
q, err := d.fixQuery(fq) |
if err != nil { |
return err |
@@ -256,7 +256,7 @@ func (d rdsImpl) Run(fq *ds.FinalizedQuery, cb ds.RawRunCB) error { |
} |
} |
-func (d rdsImpl) Count(fq *ds.FinalizedQuery) (int64, error) { |
+func (d *rdsImpl) Count(fq *ds.FinalizedQuery) (int64, error) { |
q, err := d.fixQuery(fq) |
if err != nil { |
return 0, err |
@@ -265,13 +265,40 @@ func (d rdsImpl) Count(fq *ds.FinalizedQuery) (int64, error) { |
return int64(ret), err |
} |
-func (d rdsImpl) RunInTransaction(f func(c context.Context) error, opts *ds.TransactionOptions) error { |
+func (d *rdsImpl) RunInTransaction(f func(c context.Context) error, opts *ds.TransactionOptions) error { |
ropts := (*datastore.TransactionOptions)(opts) |
return datastore.RunInTransaction(d.aeCtx, func(c context.Context) error { |
- return f(context.WithValue(d.userCtx, prodContextKey, c)) |
+ // Derive a prodState with this transaction Context. |
+ ps := d.ps |
+ ps.ctx = c |
+ ps.inTxn = true |
+ |
+ c = withProdState(d.userCtx, ps) |
+ return f(c) |
}, ropts) |
} |
-func (d rdsImpl) Testable() ds.Testable { |
+func (d *rdsImpl) WithoutTransaction() context.Context { |
+ c := d.userCtx |
+ if d.ps.inTxn { |
+ // We're in a transaction. Reset to non-transactional state. |
+ ps := d.ps |
+ ps.ctx = ps.noTxnCtx |
+ ps.inTxn = false |
+ c = withProdState(c, ps) |
+ } |
+ return c |
+} |
+ |
+func (d *rdsImpl) CurrentTransaction() ds.Transaction { |
+ if d.ps.inTxn { |
+ // Since we don't distinguish between transactions (yet), we just need this |
+ // to be non-nil. |
+ return struct{}{} |
+ } |
+ return nil |
+} |
+ |
+func (d *rdsImpl) GetTestable() ds.Testable { |
return nil |
} |