Skip to content
Snippets Groups Projects
Commit 8d5881f6 authored by Sergey Yakubov's avatar Sergey Yakubov
Browse files

refactor mongodb errors

parent b46bf564
No related branches found
No related tags found
No related merge requests found
...@@ -73,7 +73,7 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) { ...@@ -73,7 +73,7 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) {
} }
if !db.databaseInList(dbname) { if !db.databaseInList(dbname) {
return errors.New("dataset not found: " + dbname) return &DBError{utils.StatusWrongInput, "stream not found: " + dbname}
} }
return nil return nil
...@@ -81,7 +81,7 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) { ...@@ -81,7 +81,7 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) {
func (db *Mongodb) Connect(address string) (err error) { func (db *Mongodb) Connect(address string) (err error) {
if db.session != nil { if db.session != nil {
return errors.New(already_connected_msg) return &DBError{utils.StatusServiceUnavailable, already_connected_msg}
} }
db.session, err = mgo.DialWithTimeout(address, time.Second) db.session, err = mgo.DialWithTimeout(address, time.Second)
...@@ -108,14 +108,14 @@ func (db *Mongodb) Close() { ...@@ -108,14 +108,14 @@ func (db *Mongodb) Close() {
func (db *Mongodb) DeleteAllRecords(dbname string) (err error) { func (db *Mongodb) DeleteAllRecords(dbname string) (err error) {
if db.session == nil { if db.session == nil {
return errors.New(no_session_msg) return &DBError{utils.StatusServiceUnavailable, no_session_msg}
} }
return db.session.DB(dbname).DropDatabase() return db.session.DB(dbname).DropDatabase()
} }
func (db *Mongodb) InsertRecord(dbname string, s interface{}) error { func (db *Mongodb) InsertRecord(dbname string, s interface{}) error {
if db.session == nil { if db.session == nil {
return errors.New(no_session_msg) return &DBError{utils.StatusServiceUnavailable, no_session_msg}
} }
c := db.session.DB(dbname).C(data_collection_name) c := db.session.DB(dbname).C(data_collection_name)
...@@ -125,7 +125,7 @@ func (db *Mongodb) InsertRecord(dbname string, s interface{}) error { ...@@ -125,7 +125,7 @@ func (db *Mongodb) InsertRecord(dbname string, s interface{}) error {
func (db *Mongodb) InsertMeta(dbname string, s interface{}) error { func (db *Mongodb) InsertMeta(dbname string, s interface{}) error {
if db.session == nil { if db.session == nil {
return errors.New(no_session_msg) return &DBError{utils.StatusServiceUnavailable, no_session_msg}
} }
c := db.session.DB(dbname).C(meta_collection_name) c := db.session.DB(dbname).C(meta_collection_name)
...@@ -133,7 +133,7 @@ func (db *Mongodb) InsertMeta(dbname string, s interface{}) error { ...@@ -133,7 +133,7 @@ func (db *Mongodb) InsertMeta(dbname string, s interface{}) error {
return c.Insert(s) return c.Insert(s)
} }
func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int) { func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int,err error) {
c := db.session.DB(dbname).C(data_collection_name) c := db.session.DB(dbname).C(data_collection_name)
var id Pointer var id Pointer
var q bson.M var q bson.M
...@@ -142,11 +142,11 @@ func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int) { ...@@ -142,11 +142,11 @@ func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int) {
} else { } else {
q = nil q = nil
} }
err := c.Find(q).Sort("-_id").Select(bson.M{"_id": 1}).One(&id) err = c.Find(q).Sort("-_id").Select(bson.M{"_id": 1}).One(&id)
if err != nil { if err == mgo.ErrNotFound {
return 0 return 0,nil
} }
return id.ID return id.ID,err
} }
func (db *Mongodb) createLocationPointers(dbname string, group_id string) (err error) { func (db *Mongodb) createLocationPointers(dbname string, group_id string) (err error) {
...@@ -179,8 +179,10 @@ func (db *Mongodb) incrementField(dbname string, group_id string, max_ind int, r ...@@ -179,8 +179,10 @@ func (db *Mongodb) incrementField(dbname string, group_id string, max_ind int, r
_, err = c.Find(q).Apply(change, res) _, err = c.Find(q).Apply(change, res)
if err == mgo.ErrNotFound { if err == mgo.ErrNotFound {
return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind)} return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind)}
} else if err !=nil { // we do not know if counter was updated
return &DBError{utils.StatusTransactionInterrupted, err.Error()}
} }
return err return nil
} }
func encodeAnswer(id, id_max int) string { func encodeAnswer(id, id_max int) string {
...@@ -218,17 +220,20 @@ func (db *Mongodb) GetRecordByIDRow(dbname string, id, id_max int, dataset bool) ...@@ -218,17 +220,20 @@ func (db *Mongodb) GetRecordByIDRow(dbname string, id, id_max int, dataset bool)
func (db *Mongodb) GetRecordByID(dbname string, group_id string, id_str string, dataset bool) ([]byte, error) { func (db *Mongodb) GetRecordByID(dbname string, group_id string, id_str string, dataset bool) ([]byte, error) {
id, err := strconv.Atoi(id_str) id, err := strconv.Atoi(id_str)
if err != nil { if err != nil {
return nil, err return nil, &DBError{utils.StatusWrongInput, err.Error()}
} }
if err := db.checkDatabaseOperationPrerequisites(dbname, group_id); err != nil { if err := db.checkDatabaseOperationPrerequisites(dbname, group_id); err != nil {
return nil, err return nil, err
} }
max_ind := db.getMaxIndex(dbname, dataset) max_ind,err := db.getMaxIndex(dbname, dataset)
res, err := db.GetRecordByIDRow(dbname, id, max_ind, dataset) if err != nil {
return nil,err
}
return db.GetRecordByIDRow(dbname, id, max_ind, dataset)
return res, err
} }
func (db *Mongodb) needCreateLocationPointersInDb(group_id string) bool { func (db *Mongodb) needCreateLocationPointersInDb(group_id string) bool {
...@@ -264,11 +269,11 @@ func (db *Mongodb) getParentDB() *Mongodb { ...@@ -264,11 +269,11 @@ func (db *Mongodb) getParentDB() *Mongodb {
func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id string) error { func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id string) error {
if db.session == nil { if db.session == nil {
return &DBError{utils.StatusError, no_session_msg} return &DBError{utils.StatusServiceUnavailable, no_session_msg}
} }
if err := db.getParentDB().dataBaseExist(db_name); err != nil { if err := db.getParentDB().dataBaseExist(db_name); err != nil {
return &DBError{utils.StatusWrongInput, err.Error()} return err
} }
if len(group_id) > 0 { if len(group_id) > 0 {
...@@ -278,10 +283,13 @@ func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id ...@@ -278,10 +283,13 @@ func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id
} }
func (db *Mongodb) getCurrentPointer(db_name string, group_id string, dataset bool) (Pointer, int, error) { func (db *Mongodb) getCurrentPointer(db_name string, group_id string, dataset bool) (Pointer, int, error) {
max_ind := db.getMaxIndex(db_name, dataset) max_ind,err := db.getMaxIndex(db_name, dataset)
if err != nil {
return Pointer{}, 0, err
}
var curPointer Pointer var curPointer Pointer
err := db.incrementField(db_name, group_id, max_ind, &curPointer) err = db.incrementField(db_name, group_id, max_ind, &curPointer)
if err != nil { if err != nil {
return Pointer{}, 0, err return Pointer{}, 0, err
} }
...@@ -303,16 +311,19 @@ func (db *Mongodb) GetNextRecord(db_name string, group_id string, dataset bool) ...@@ -303,16 +311,19 @@ func (db *Mongodb) GetNextRecord(db_name string, group_id string, dataset bool)
log_str := "got next pointer " + strconv.Itoa(curPointer.Value) + " for " + db_name + ", groupid: " + group_id log_str := "got next pointer " + strconv.Itoa(curPointer.Value) + " for " + db_name + ", groupid: " + group_id
logger.Debug(log_str) logger.Debug(log_str)
return db.GetRecordByIDRow(db_name, curPointer.Value, max_ind, dataset) return db.GetRecordByIDRow(db_name, curPointer.Value, max_ind, dataset)
} }
func (db *Mongodb) GetLastRecord(db_name string, group_id string, dataset bool) ([]byte, error) { func (db *Mongodb) GetLastRecord(db_name string, group_id string, dataset bool) ([]byte, error) {
if err := db.checkDatabaseOperationPrerequisites(db_name, group_id); err != nil { if err := db.checkDatabaseOperationPrerequisites(db_name, group_id); err != nil {
return nil, err return nil, err
} }
max_ind := db.getMaxIndex(db_name, dataset) max_ind,err := db.getMaxIndex(db_name, dataset)
if err !=nil {
return nil,err
}
res, err := db.GetRecordByIDRow(db_name, max_ind, max_ind, dataset) res, err := db.GetRecordByIDRow(db_name, max_ind, max_ind, dataset)
db.setCounter(db_name, group_id, max_ind) db.setCounter(db_name, group_id, max_ind)
......
...@@ -67,7 +67,7 @@ func TestMongoDBConnectOK(t *testing.T) { ...@@ -67,7 +67,7 @@ func TestMongoDBConnectOK(t *testing.T) {
func TestMongoDBGetNextErrorWhenNotConnected(t *testing.T) { func TestMongoDBGetNextErrorWhenNotConnected(t *testing.T) {
_, err := db.GetNextRecord("", groupId, false) _, err := db.GetNextRecord("", groupId, false)
assert.Equal(t, utils.StatusError, err.(*DBError).Code) assert.Equal(t, utils.StatusServiceUnavailable, err.(*DBError).Code)
} }
func TestMongoDBGetNextErrorWhenWrongDatabasename(t *testing.T) { func TestMongoDBGetNextErrorWhenWrongDatabasename(t *testing.T) {
...@@ -287,7 +287,7 @@ func TestMongoDBGetSizeNoDatabase(t *testing.T) { ...@@ -287,7 +287,7 @@ func TestMongoDBGetSizeNoDatabase(t *testing.T) {
func TestMongoDBGetRecordByIDNotConnected(t *testing.T) { func TestMongoDBGetRecordByIDNotConnected(t *testing.T) {
_, err := db.GetRecordByID(dbname, "", "2", false) _, err := db.GetRecordByID(dbname, "", "2", false)
assert.Equal(t, utils.StatusError, err.(*DBError).Code) assert.Equal(t, utils.StatusServiceUnavailable, err.(*DBError).Code)
} }
func TestMongoDBResetCounter(t *testing.T) { func TestMongoDBResetCounter(t *testing.T) {
......
...@@ -68,8 +68,8 @@ func processRequest(w http.ResponseWriter, r *http.Request, op string, extra_par ...@@ -68,8 +68,8 @@ func processRequest(w http.ResponseWriter, r *http.Request, op string, extra_par
} }
func returnError(err error, log_str string) (answer []byte, code int) { func returnError(err error, log_str string) (answer []byte, code int) {
code = utils.StatusServiceUnavailable
err_db, ok := err.(*database.DBError) err_db, ok := err.(*database.DBError)
code = utils.StatusError
if ok { if ok {
code = err_db.Code code = err_db.Code
} }
......
...@@ -136,7 +136,7 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithInternalDBError() { ...@@ -136,7 +136,7 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithInternalDBError() {
ExpectCopyClose(suite.mock_db) ExpectCopyClose(suite.mock_db)
w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix) w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
suite.Equal(http.StatusInternalServerError, w.Code, "internal error") suite.Equal(http.StatusNotFound, w.Code, "internal error")
} }
func (suite *ProcessRequestTestSuite) TestProcessRequestAddsCounter() { func (suite *ProcessRequestTestSuite) TestProcessRequestAddsCounter() {
......
...@@ -8,7 +8,8 @@ const ( ...@@ -8,7 +8,8 @@ const (
) )
const ( const (
//error codes //error codes
StatusError = http.StatusInternalServerError StatusTransactionInterrupted = http.StatusInternalServerError
StatusWrongInput = http.StatusBadRequest StatusServiceUnavailable = http.StatusNotFound
StatusNoData = http.StatusConflict StatusWrongInput = http.StatusBadRequest
StatusNoData = http.StatusConflict
) )
...@@ -33,7 +33,7 @@ def assert_eq(val,expected,name): ...@@ -33,7 +33,7 @@ def assert_eq(val,expected,name):
def check_broker_server_error(broker,group_id_new): def check_broker_server_error(broker,group_id_new):
try: try:
broker.get_last(group_id_new, meta_only=True) broker.get_last(group_id_new, meta_only=True)
except asapo_consumer.AsapoBrokerServerError as err: except asapo_consumer.AsapoBrokerServersNotFound as err:
print(err) print(err)
pass pass
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment