Skip to content
Snippets Groups Projects
Commit e69d6b88 authored by Sergey Yakubov's avatar Sergey Yakubov
Browse files

replace go mongo driver

parent e418f876
No related branches found
No related tags found
No related merge requests found
Showing
with 145 additions and 153 deletions
......@@ -5,7 +5,6 @@ type Agent interface {
Ping() error
Connect(string) error
Close()
Copy() Agent
}
type DBError struct {
......
......@@ -16,7 +16,6 @@ func TestMockDataBase(t *testing.T) {
db.Connect("")
db.Close()
db.Copy()
db.Ping()
var err DBError
err.Error()
......
......@@ -24,11 +24,6 @@ func (db *MockedDatabase) Ping() error {
return args.Error(0)
}
func (db *MockedDatabase) Copy() Agent {
db.Called()
return db
}
func (db *MockedDatabase) ProcessRequest(db_name string, group_id string, op string, extra_param string) (answer []byte, err error) {
args := db.Called(db_name, group_id, op, extra_param)
return args.Get(0).([]byte), args.Error(1)
......
......@@ -5,26 +5,32 @@ package database
import (
"asapo_common/logger"
"asapo_common/utils"
"context"
"encoding/json"
"errors"
"github.com/globalsign/mgo"
"github.com/globalsign/mgo/bson"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"strconv"
"strings"
"sync"
"time"
)
type Pointer struct {
ID int `bson:"_id"`
Value int `bson:"current_pointer"`
type ID struct {
ID int `bson:"_id"`
}
type LocationPointer struct {
GroupID string `bson:"_id"`
Value int `bson:"current_pointer"`
}
const data_collection_name = "data"
const meta_collection_name = "meta"
const pointer_collection_name = "current_location"
const pointer_field_name = "current_pointer"
const no_session_msg = "database session not created"
const no_session_msg = "database client not created"
const wrong_id_type = "wrong id type"
const already_connected_msg = "already connected"
......@@ -37,24 +43,13 @@ type SizeRecord struct {
}
type Mongodb struct {
session *mgo.Session
client *mongo.Client
timeout time.Duration
databases []string
parent_db *Mongodb
db_pointers_created map[string]bool
}
func (db *Mongodb) Copy() Agent {
new_db := new(Mongodb)
if db.session != nil {
dbSessionLock.RLock()
new_db.session = db.session.Copy()
dbSessionLock.RUnlock()
}
new_db.parent_db = db
return new_db
}
func (db *Mongodb) databaseInList(dbname string) bool {
dbListLock.RLock()
defer dbListLock.RUnlock()
......@@ -63,16 +58,20 @@ func (db *Mongodb) databaseInList(dbname string) bool {
func (db *Mongodb) updateDatabaseList() (err error) {
dbListLock.Lock()
db.databases, err = db.session.DatabaseNames()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
db.databases, err = db.client.ListDatabaseNames(ctx, bson.M{})
dbListLock.Unlock()
return err
}
func (db *Mongodb) Ping() (err error) {
if db.session == nil {
if db.client == nil {
return &DBError{utils.StatusServiceUnavailable, no_session_msg}
}
return db.session.Ping()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
return db.client.Ping(ctx, nil)
}
func (db *Mongodb) dataBaseExist(dbname string) (err error) {
......@@ -92,18 +91,26 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) {
}
func (db *Mongodb) Connect(address string) (err error) {
if db.session != nil {
if db.client != nil {
return &DBError{utils.StatusServiceUnavailable, already_connected_msg}
}
db.session, err = mgo.DialWithTimeout(address, time.Second)
db.client, err = mongo.NewClient(options.Client().SetConnectTimeout(20 * time.Second).ApplyURI("mongodb://" + address))
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
err = db.client.Connect(ctx)
if err != nil {
db.client = nil
return err
}
// db.session.SetSafe(&mgo.Safe{J: true})
// db.client.SetSafe(&mgo.Safe{J: true})
if err := db.updateDatabaseList(); err != nil {
db.Close()
return err
}
......@@ -111,90 +118,94 @@ func (db *Mongodb) Connect(address string) (err error) {
}
func (db *Mongodb) Close() {
if db.session != nil {
if db.client != nil {
dbSessionLock.Lock()
db.session.Close()
db.session = nil
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
db.client.Disconnect(ctx)
db.client = nil
dbSessionLock.Unlock()
}
}
func (db *Mongodb) deleteAllRecords(dbname string) (err error) {
if db.session == nil {
if db.client == nil {
return &DBError{utils.StatusServiceUnavailable, no_session_msg}
}
return db.session.DB(dbname).DropDatabase()
return db.client.Database(dbname).Drop(context.TODO())
}
func (db *Mongodb) insertRecord(dbname string, s interface{}) error {
if db.session == nil {
if db.client == nil {
return &DBError{utils.StatusServiceUnavailable, no_session_msg}
}
c := db.session.DB(dbname).C(data_collection_name)
c := db.client.Database(dbname).Collection(data_collection_name)
return c.Insert(s)
_, err := c.InsertOne(context.TODO(), s)
return err
}
func (db *Mongodb) insertMeta(dbname string, s interface{}) error {
if db.session == nil {
if db.client == nil {
return &DBError{utils.StatusServiceUnavailable, no_session_msg}
}
c := db.session.DB(dbname).C(meta_collection_name)
c := db.client.Database(dbname).Collection(meta_collection_name)
return c.Insert(s)
_, err := c.InsertOne(context.TODO(), s)
return err
}
func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int, err error) {
c := db.session.DB(dbname).C(data_collection_name)
var id Pointer
c := db.client.Database(dbname).Collection(data_collection_name)
var q bson.M
if dataset {
q = bson.M{"$expr": bson.M{"$eq": []interface{}{"$size", bson.M{"$size": "$images"}}}}
} else {
q = nil
}
err = c.Find(q).Sort("-_id").Select(bson.M{"_id": 1}).One(&id)
if err == mgo.ErrNotFound {
opts := options.FindOne().SetSort(bson.M{"_id": -1}).SetReturnKey(true)
var result ID
err = c.FindOne(context.TODO(), q, opts).Decode(&result)
if err == mongo.ErrNoDocuments {
return 0, nil
}
return id.ID, err
return result.ID, err
}
func (db *Mongodb) createLocationPointers(dbname string, group_id string) (err error) {
change := mgo.Change{
Update: bson.M{"$inc": bson.M{pointer_field_name: 0}},
Upsert: true,
}
opts := options.Update().SetUpsert(true)
update := bson.M{"$inc": bson.M{pointer_field_name: 0}}
q := bson.M{"_id": group_id}
c := db.session.DB(dbname).C(pointer_collection_name)
var res map[string]interface{}
_, err = c.Find(q).Apply(change, &res)
return err
c := db.client.Database(dbname).Collection(pointer_collection_name)
_, err = c.UpdateOne(context.TODO(), q, update, opts)
return
}
func (db *Mongodb) setCounter(dbname string, group_id string, ind int) (err error) {
update := bson.M{"$set": bson.M{pointer_field_name: ind}}
c := db.session.DB(dbname).C(pointer_collection_name)
return c.UpdateId(group_id, update)
c := db.client.Database(dbname).Collection(pointer_collection_name)
q := bson.M{"_id": group_id}
_, err = c.UpdateOne(context.TODO(), q, update, options.Update())
return
}
func (db *Mongodb) incrementField(dbname string, group_id string, max_ind int, res interface{}) (err error) {
update := bson.M{"$inc": bson.M{pointer_field_name: 1}}
change := mgo.Change{
Update: update,
Upsert: false,
ReturnNew: true,
}
opts := options.FindOneAndUpdate().SetUpsert(false).SetReturnDocument(options.After)
q := bson.M{"_id": group_id, pointer_field_name: bson.M{"$lt": max_ind}}
c := db.session.DB(dbname).C(pointer_collection_name)
_, err = c.Find(q).Apply(change, res)
if err == mgo.ErrNotFound {
return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind)}
} else if err != nil { // we do not know if counter was updated
c := db.client.Database(dbname).Collection(pointer_collection_name)
err = c.FindOneAndUpdate(context.TODO(), q, update, opts).Decode(res)
if err != nil {
if err == mongo.ErrNoDocuments {
return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind)}
}
return &DBError{utils.StatusTransactionInterrupted, err.Error()}
}
return nil
}
......@@ -217,8 +228,8 @@ func (db *Mongodb) getRecordByIDRow(dbname string, id, id_max int, dataset bool)
q = bson.M{"_id": id}
}
c := db.session.DB(dbname).C(data_collection_name)
err := c.Find(q).One(&res)
c := db.client.Database(dbname).Collection(data_collection_name)
err := c.FindOne(context.TODO(), q, options.FindOne()).Decode(&res)
if err != nil {
answer := encodeAnswer(id, id_max)
log_str := "error getting record id " + strconv.Itoa(id) + " for " + dbname + " : " + err.Error()
......@@ -277,7 +288,7 @@ func (db *Mongodb) getParentDB() *Mongodb {
}
func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id string) error {
if db.session == nil {
if db.client == nil {
return &DBError{utils.StatusServiceUnavailable, no_session_msg}
}
......@@ -291,16 +302,16 @@ func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id
return nil
}
func (db *Mongodb) getCurrentPointer(db_name string, group_id string, dataset bool) (Pointer, int, error) {
func (db *Mongodb) getCurrentPointer(db_name string, group_id string, dataset bool) (LocationPointer, int, error) {
max_ind, err := db.getMaxIndex(db_name, dataset)
if err != nil {
return Pointer{}, 0, err
return LocationPointer{}, 0, err
}
var curPointer Pointer
var curPointer LocationPointer
err = db.incrementField(db_name, group_id, max_ind, &curPointer)
if err != nil {
return Pointer{}, 0, err
return LocationPointer{}, 0, err
}
return curPointer, max_ind, nil
......@@ -331,13 +342,15 @@ func (db *Mongodb) getLastRecord(db_name string, group_id string, dataset bool)
}
func (db *Mongodb) getSize(db_name string) ([]byte, error) {
c := db.session.DB(db_name).C(data_collection_name)
c := db.client.Database(db_name).Collection(data_collection_name)
var rec SizeRecord
var err error
rec.Size, err = c.Count()
size, err := c.CountDocuments(context.TODO(), bson.M{}, options.Count())
if err != nil {
return nil, err
}
rec.Size = int(size)
return json.Marshal(&rec)
}
......@@ -360,8 +373,8 @@ func (db *Mongodb) getMeta(dbname string, id_str string) ([]byte, error) {
var res map[string]interface{}
q := bson.M{"_id": id}
c := db.session.DB(dbname).C(meta_collection_name)
err = c.Find(q).One(&res)
c := db.client.Database(dbname).Collection(meta_collection_name)
err = c.FindOne(context.TODO(), q, options.FindOne()).Decode(&res)
if err != nil {
log_str := "error getting meta with id " + strconv.Itoa(id) + " for " + dbname + " : " + err.Error()
logger.Debug(log_str)
......@@ -372,6 +385,12 @@ func (db *Mongodb) getMeta(dbname string, id_str string) ([]byte, error) {
return utils.MapToJson(&res)
}
func (db *Mongodb) processQueryError(query, dbname string, err error) ([]byte, error) {
log_str := "error processing query: " + query + " for " + dbname + " : " + err.Error()
logger.Debug(log_str)
return nil, &DBError{utils.StatusNoData, err.Error()}
}
func (db *Mongodb) queryImages(dbname string, query string) ([]byte, error) {
var res []map[string]interface{}
q, sort, err := db.BSONFromSQL(dbname, query)
......@@ -381,16 +400,21 @@ func (db *Mongodb) queryImages(dbname string, query string) ([]byte, error) {
return nil, &DBError{utils.StatusWrongInput, err.Error()}
}
c := db.session.DB(dbname).C(data_collection_name)
c := db.client.Database(dbname).Collection(data_collection_name)
opts := options.Find()
if len(sort) > 0 {
err = c.Find(q).Sort(sort).All(&res)
opts = opts.SetSort(sort)
} else {
err = c.Find(q).All(&res)
}
cursor, err := c.Find(context.TODO(), q, opts)
if err != nil {
log_str := "error processing query: " + query + " for " + dbname + " : " + err.Error()
logger.Debug(log_str)
return nil, &DBError{utils.StatusNoData, err.Error()}
return db.processQueryError(query, dbname, err)
}
err = cursor.All(context.TODO(), &res)
if err != nil {
return db.processQueryError(query, dbname, err)
}
log_str := "processed query " + query + " for " + dbname + " ,found" + strconv.Itoa(len(res)) + " records"
......
......@@ -178,39 +178,40 @@ func getBSONFromExpression(node sqlparser.Expr) (res bson.M, err error) {
}
}
func getSortBSONFromOrderArray(order_array sqlparser.OrderBy) (string, error) {
func getSortBSONFromOrderArray(order_array sqlparser.OrderBy) (bson.M, error) {
if len(order_array) != 1 {
return "", errors.New("order by should have single column name")
return bson.M{}, errors.New("order by should have single column name")
}
order := order_array[0]
val, ok := order.Expr.(*sqlparser.ColName)
if !ok {
return "", errors.New("order be key name")
return bson.M{}, errors.New("order has to be key name")
}
name := keyFromColumnName(val)
sign := 1
if order.Direction == sqlparser.DescScr {
name = "-" + name
sign = -1
}
return name, nil
return bson.M{name: sign}, nil
}
func (db *Mongodb) BSONFromSQL(dbname string, query string) (bson.M, string, error) {
func (db *Mongodb) BSONFromSQL(dbname string, query string) (bson.M, bson.M, error) {
stmt, err := sqlparser.Parse("select * from " + dbname + " where " + query)
if err != nil {
return bson.M{}, "", err
return bson.M{}, bson.M{}, err
}
sel, _ := stmt.(*sqlparser.Select)
query_mongo, err := getBSONFromExpression(sel.Where.Expr)
if err != nil || len(sel.OrderBy) == 0 {
return query_mongo, "", err
return query_mongo, bson.M{}, err
}
sort_mongo, err := getSortBSONFromOrderArray(sel.OrderBy)
if err != nil {
return bson.M{}, "", err
return bson.M{}, bson.M{}, err
}
return query_mongo, sort_mongo, nil
......
......@@ -4,10 +4,11 @@ package database
import (
"asapo_common/utils"
"context"
"encoding/json"
"fmt"
"github.com/stretchr/testify/assert"
"strings"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
"sync"
"testing"
)
......@@ -45,13 +46,16 @@ var recs2 = SizeRecord{0}
var recs2_expect, _ = json.Marshal(recs2)
func cleanup() {
if db.client == nil {
return
}
db.deleteAllRecords(dbname)
db.db_pointers_created = nil
db.Close()
}
// these are tjhe integration tests. They assume mongo db is runnig on 127.0.0.1:27027
// test names shlud contain MongoDB*** so that go test could find them:
// these are the integration tests. They assume mongo db is runnig on 127.0.0.1:27027
// test names should contain MongoDB*** so that go test could find them:
// go_integration_test(${TARGET_NAME}-connectdb "./..." "MongoDBConnect")
func TestMongoDBConnectFails(t *testing.T) {
err := db.Connect("blabla")
......@@ -65,12 +69,6 @@ func TestMongoDBConnectOK(t *testing.T) {
assert.Nil(t, err)
}
func TestMongoCopyWhenNoSession(t *testing.T) {
db_new := db.Copy()
err := db_new.Connect("sss")
assert.NotNil(t, err)
}
func TestMongoDBGetNextErrorWhenNotConnected(t *testing.T) {
_, err := db.ProcessRequest(dbname, groupId, "next", "")
assert.Equal(t, utils.StatusServiceUnavailable, err.(*DBError).Code)
......@@ -130,14 +128,6 @@ func TestMongoDBGetNextErrorOnNoMoreData(t *testing.T) {
assert.Equal(t, "{\"op\":\"get_record_by_id\",\"id\":1,\"id_max\":1}", err.(*DBError).Message)
}
//func TestMongoDBGetNextErrorOnDataAtAll(t *testing.T) {
// db.Connect(dbaddress)
// defer cleanup()
// _, err := db.GetNextRecord(dbname, groupId, false)
// assert.Equal(t, utils.StatusNoData, err.(*DBError).Code)
// assert.Equal(t, "{\"op\":\"get_record_by_id\",\"id\":0,\"id_max\":0}", err.(*DBError).Message)
//}
func TestMongoDBGetNextCorrectOrder(t *testing.T) {
db.Connect(dbaddress)
defer cleanup()
......@@ -288,7 +278,7 @@ func TestMongoDBGetSizeNoRecords(t *testing.T) {
defer cleanup()
// to have empty collection
db.insertRecord(dbname, &rec1)
db.session.DB(dbname).C(data_collection_name).RemoveId(1)
db.client.Database(dbname).Collection(data_collection_name).DeleteOne(context.TODO(), bson.M{"_id": 1}, options.Delete())
res, err := db.ProcessRequest(dbname, "", "size", "0")
assert.Nil(t, err)
......@@ -424,18 +414,19 @@ func TestMongoDBQueryImagesOK(t *testing.T) {
db.insertRecord(dbname, &recq4)
for _, test := range tests {
info, _ := db.session.BuildInfo()
if strings.Contains(test.query, "NOT REGEXP") && !info.VersionAtLeast(4, 0, 7) {
fmt.Println("Skipping NOT REGEXP test since it is not supported by this mongodb version")
continue
}
// info, _ := db.client.BuildInfo()
// if strings.Contains(test.query, "NOT REGEXP") && !info.VersionAtLeast(4, 0, 7) {
// fmt.Println("Skipping NOT REGEXP test since it is not supported by this mongodb version")
// continue
// }
res_string, err := db.ProcessRequest(dbname, "", "queryimages", test.query)
var res []TestRecordMeta
json.Unmarshal(res_string, &res)
// fmt.Println(string(res_string))
if test.ok {
assert.Nil(t, err)
assert.Equal(t, test.res, res, test.query)
assert.Nil(t, err, test.query)
assert.Equal(t, test.res, res)
} else {
assert.NotNil(t, err, test.query)
assert.Equal(t, 0, len(res))
......
......@@ -5,9 +5,7 @@ import (
)
func routeGetHealth(w http.ResponseWriter, r *http.Request) {
db_new := db.Copy()
defer db_new.Close()
err := db_new.Ping()
err := db.Ping()
if err != nil {
ReconnectDb()
}
......
......@@ -33,15 +33,13 @@ func TestGetHealthTestSuite(t *testing.T) {
func (suite *GetHealthTestSuite) TestGetHealthOk() {
suite.mock_db.On("Ping").Return(nil)
ExpectCopyClose(suite.mock_db)
w := doRequest("/health")
suite.Equal(http.StatusNoContent, w.Code)
}
func (suite *GetHealthTestSuite) TestGetHealthTriesToReconnectsToDataBase() {
suite.mock_db.On("Ping").Return(errors.New("ping error"))
ExpectCopyCloseReconnect(suite.mock_db)
ExpectReconnect(suite.mock_db)
w := doRequest("/health")
suite.Equal(http.StatusNoContent, w.Code)
......
......@@ -15,11 +15,6 @@ func TestGetIdWithoutDatabaseName(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w.Code, "no database name")
}
func ExpectCopyCloseOnID(mock_db *database.MockedDatabase) {
mock_db.On("Copy").Return(mock_db)
mock_db.On("Close").Return()
}
type GetIDTestSuite struct {
suite.Suite
mock_db *database.MockedDatabase
......@@ -31,7 +26,6 @@ func (suite *GetIDTestSuite) SetupTest() {
suite.mock_db = new(database.MockedDatabase)
db = suite.mock_db
logger.SetMockLog()
ExpectCopyCloseOnID(suite.mock_db)
}
func (suite *GetIDTestSuite) TearDownTest() {
......@@ -47,7 +41,6 @@ func TestGetIDTestSuite(t *testing.T) {
func (suite *GetIDTestSuite) TestGetIdCallsCorrectRoutine() {
suite.mock_db.On("ProcessRequest", expectedDBName, expectedGroupID, "id", "1").Return([]byte("Hello"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/1" + correctTokenSuffix)
suite.Equal(http.StatusOK, w.Code, "GetImage OK")
......
......@@ -35,7 +35,6 @@ func TestGetLastTestSuite(t *testing.T) {
func (suite *GetLastTestSuite) TestGetLastCallsCorrectRoutine() {
suite.mock_db.On("ProcessRequest", expectedDBName, expectedGroupID, "last", "0").Return([]byte("Hello"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request last")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/last" + correctTokenSuffix)
suite.Equal(http.StatusOK, w.Code, "GetLast OK")
......
......@@ -35,7 +35,6 @@ func TestGetMetaTestSuite(t *testing.T) {
func (suite *GetMetaTestSuite) TestGetMetaOK() {
suite.mock_db.On("ProcessRequest", expectedDBName, "", "meta", "0").Return([]byte("{\"test\":10}"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request meta")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/0/meta/0" + correctTokenSuffix)
suite.Equal(http.StatusOK, w.Code, "GetSize OK")
......
......@@ -35,7 +35,6 @@ func TestGetNextTestSuite(t *testing.T) {
func (suite *GetNextTestSuite) TestGetNextCallsCorrectRoutine() {
suite.mock_db.On("ProcessRequest", expectedDBName, expectedGroupID, "next", "0").Return([]byte("Hello"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request next")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
suite.Equal(http.StatusOK, w.Code, "GetNext OK")
......
......@@ -35,7 +35,6 @@ func TestGetSizeTestSuite(t *testing.T) {
func (suite *GetSizeTestSuite) TestGetSizeOK() {
suite.mock_db.On("ProcessRequest", expectedDBName, "", "size", "0").Return([]byte("{\"size\":10}"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request size")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/size" + correctTokenSuffix)
suite.Equal(http.StatusOK, w.Code, "GetSize OK")
......
......@@ -36,7 +36,6 @@ func (suite *QueryTestSuite) TestQueryOK() {
query_str := "aaaa"
suite.mock_db.On("ProcessRequest", expectedDBName, "", "queryimages", query_str).Return([]byte("{}"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request queryimages")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/"+expectedBeamtimeId+"/"+expectedStream+"/0/queryimages"+correctTokenSuffix, "POST", query_str)
suite.Equal(http.StatusOK, w.Code, "Query OK")
......
......@@ -35,7 +35,6 @@ func TestResetCounterTestSuite(t *testing.T) {
func (suite *ResetCounterTestSuite) TestResetCounterOK() {
suite.mock_db.On("ProcessRequest", expectedDBName, expectedGroupID, "resetcounter", "10").Return([]byte(""), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request resetcounter")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/"+expectedBeamtimeId+"/"+expectedStream+"/"+expectedGroupID+"/resetcounter"+correctTokenSuffix+"&value=10", "POST")
suite.Equal(http.StatusOK, w.Code, "ResetCounter OK")
......
......@@ -99,10 +99,8 @@ func reconnectIfNeeded(db_error error) {
}
func processRequestInDb(db_name string, group_id string, op string, extra_param string) (answer []byte, code int) {
db_new := db.Copy()
defer db_new.Close()
statistics.IncreaseCounter()
answer, err := db_new.ProcessRequest(db_name, group_id, op, extra_param)
answer, err := db.ProcessRequest(db_name, group_id, op, extra_param)
log_str := "processing request " + op + " in " + db_name + " at " + settings.GetDatabaseServer()
if err != nil {
go reconnectIfNeeded(err)
......
......@@ -75,14 +75,8 @@ func TestProcessRequestWithoutDatabaseName(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w.Code, "no database name")
}
func ExpectCopyClose(mock_db *database.MockedDatabase) {
mock_db.On("Copy").Return(mock_db)
func ExpectReconnect(mock_db *database.MockedDatabase) {
mock_db.On("Close").Return()
}
func ExpectCopyCloseReconnect(mock_db *database.MockedDatabase) {
mock_db.On("Copy").Return(mock_db)
mock_db.On("Close").Twice().Return()
mock_db.On("Connect", mock.AnythingOfType("string")).Return(nil)
}
......@@ -130,7 +124,6 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithWrongDatabaseName()
&database.DBError{utils.StatusWrongInput, ""})
logger.MockLog.On("Error", mock.MatchedBy(containsMatcher("processing request next")))
ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
......@@ -142,7 +135,7 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithConnectionError() {
&database.DBError{utils.StatusServiceUnavailable, ""})
logger.MockLog.On("Error", mock.MatchedBy(containsMatcher("processing request next")))
ExpectCopyCloseReconnect(suite.mock_db)
ExpectReconnect(suite.mock_db)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("reconnected")))
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
......@@ -155,7 +148,7 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithInternalDBError() {
logger.MockLog.On("Error", mock.MatchedBy(containsMatcher("processing request next")))
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("reconnected")))
ExpectCopyCloseReconnect(suite.mock_db)
ExpectReconnect(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
time.Sleep(time.Second)
......@@ -165,7 +158,6 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithInternalDBError() {
func (suite *ProcessRequestTestSuite) TestProcessRequestAddsCounter() {
suite.mock_db.On("ProcessRequest", expectedDBName, expectedGroupID, "next", "0").Return([]byte("Hello"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request next in "+expectedDBName)))
ExpectCopyClose(suite.mock_db)
doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
suite.Equal(1, statistics.GetCounter(), "ProcessRequest increases counter")
......@@ -180,7 +172,6 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWrongGroupID() {
func (suite *ProcessRequestTestSuite) TestProcessRequestAddsDataset() {
suite.mock_db.On("ProcessRequest", expectedDBName, expectedGroupID, "next_dataset", "0").Return([]byte("Hello"), nil)
logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request next_dataset in "+expectedDBName)))
ExpectCopyClose(suite.mock_db)
doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix + "&dataset=true")
}
File added
12ljzgneasfd
{
"Port": {{ env "NOMAD_PORT_authorizer" }},
"LogLevel":"debug",
"AlwaysAllowedBeamtimes":[{"BeamtimeId":"asapo_test","Beamline":"test"},
{"BeamtimeId":"asapo_test1","Beamline":"test1"},
{"BeamtimeId":"asapo_test2","Beamline":"test2"}],
"SecretFile":"auth_secret.key"
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment