From 8d5881f61fabbd6656ba28b7110287a75cbd42b2 Mon Sep 17 00:00:00 2001
From: Sergey Yakubov <sergey.yakubov@desy.de>
Date: Wed, 25 Sep 2019 23:22:34 +0200
Subject: [PATCH] refactor mongodb errors

---
 broker/src/asapo_broker/database/mongodb.go   | 53 +++++++++++--------
 .../src/asapo_broker/database/mongodb_test.go |  4 +-
 .../asapo_broker/server/process_request.go    |  2 +-
 .../server/process_request_test.go            |  2 +-
 .../go/src/asapo_common/utils/status_codes.go |  7 +--
 .../consumer_api_python/consumer_api.py       |  2 +-
 6 files changed, 41 insertions(+), 29 deletions(-)

diff --git a/broker/src/asapo_broker/database/mongodb.go b/broker/src/asapo_broker/database/mongodb.go
index 2ccfd229f..dc3b81ace 100644
--- a/broker/src/asapo_broker/database/mongodb.go
+++ b/broker/src/asapo_broker/database/mongodb.go
@@ -73,7 +73,7 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) {
 	}
 
 	if !db.databaseInList(dbname) {
-		return errors.New("dataset not found: " + dbname)
+		return &DBError{utils.StatusWrongInput, "stream not found: " + dbname}
 	}
 
 	return nil
@@ -81,7 +81,7 @@ func (db *Mongodb) dataBaseExist(dbname string) (err error) {
 
 func (db *Mongodb) Connect(address string) (err error) {
 	if db.session != nil {
-		return errors.New(already_connected_msg)
+		return &DBError{utils.StatusServiceUnavailable, already_connected_msg}
 	}
 
 	db.session, err = mgo.DialWithTimeout(address, time.Second)
@@ -108,14 +108,14 @@ func (db *Mongodb) Close() {
 
 func (db *Mongodb) DeleteAllRecords(dbname string) (err error) {
 	if db.session == nil {
-		return errors.New(no_session_msg)
+		return &DBError{utils.StatusServiceUnavailable, no_session_msg}
 	}
 	return db.session.DB(dbname).DropDatabase()
 }
 
 func (db *Mongodb) InsertRecord(dbname string, s interface{}) error {
 	if db.session == nil {
-		return errors.New(no_session_msg)
+		return &DBError{utils.StatusServiceUnavailable, no_session_msg}
 	}
 
 	c := db.session.DB(dbname).C(data_collection_name)
@@ -125,7 +125,7 @@ func (db *Mongodb) InsertRecord(dbname string, s interface{}) error {
 
 func (db *Mongodb) InsertMeta(dbname string, s interface{}) error {
 	if db.session == nil {
-		return errors.New(no_session_msg)
+		return &DBError{utils.StatusServiceUnavailable, no_session_msg}
 	}
 
 	c := db.session.DB(dbname).C(meta_collection_name)
@@ -133,7 +133,7 @@ func (db *Mongodb) InsertMeta(dbname string, s interface{}) error {
 	return c.Insert(s)
 }
 
-func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int) {
+func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int,err error) {
 	c := db.session.DB(dbname).C(data_collection_name)
 	var id Pointer
 	var q bson.M
@@ -142,11 +142,11 @@ func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int) {
 	} else {
 		q = nil
 	}
-	err := c.Find(q).Sort("-_id").Select(bson.M{"_id": 1}).One(&id)
-	if err != nil {
-		return 0
+	err = c.Find(q).Sort("-_id").Select(bson.M{"_id": 1}).One(&id)
+	if err == mgo.ErrNotFound {
+		return 0,nil
 	}
-	return id.ID
+	return id.ID,err
 }
 
 func (db *Mongodb) createLocationPointers(dbname string, group_id string) (err error) {
@@ -179,8 +179,10 @@ func (db *Mongodb) incrementField(dbname string, group_id string, max_ind int, r
 	_, err = c.Find(q).Apply(change, res)
 	if err == mgo.ErrNotFound {
 		return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind)}
+	} else if err !=nil { // we do not know if counter was updated
+		return &DBError{utils.StatusTransactionInterrupted, err.Error()}
 	}
-	return err
+	return nil
 }
 
 func encodeAnswer(id, id_max int) string {
@@ -218,17 +220,20 @@ func (db *Mongodb) GetRecordByIDRow(dbname string, id, id_max int, dataset bool)
 func (db *Mongodb) GetRecordByID(dbname string, group_id string, id_str string, dataset bool) ([]byte, error) {
 	id, err := strconv.Atoi(id_str)
 	if err != nil {
-		return nil, err
+		return nil, &DBError{utils.StatusWrongInput, err.Error()}
 	}
 
 	if err := db.checkDatabaseOperationPrerequisites(dbname, group_id); err != nil {
 		return nil, err
 	}
 
-	max_ind := db.getMaxIndex(dbname, dataset)
-	res, err := db.GetRecordByIDRow(dbname, id, max_ind, dataset)
+	max_ind,err := db.getMaxIndex(dbname, dataset)
+	if err != nil {
+		return nil,err
+	}
+
+	return  db.GetRecordByIDRow(dbname, id, max_ind, dataset)
 
-	return res, err
 }
 
 func (db *Mongodb) needCreateLocationPointersInDb(group_id string) bool {
@@ -264,11 +269,11 @@ func (db *Mongodb) getParentDB() *Mongodb {
 
 func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id string) error {
 	if db.session == nil {
-		return &DBError{utils.StatusError, no_session_msg}
+		return &DBError{utils.StatusServiceUnavailable, no_session_msg}
 	}
 
 	if err := db.getParentDB().dataBaseExist(db_name); err != nil {
-		return &DBError{utils.StatusWrongInput, err.Error()}
+		return err
 	}
 
 	if len(group_id) > 0 {
@@ -278,10 +283,13 @@ func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id
 }
 
 func (db *Mongodb) getCurrentPointer(db_name string, group_id string, dataset bool) (Pointer, int, error) {
-	max_ind := db.getMaxIndex(db_name, dataset)
+	max_ind,err := db.getMaxIndex(db_name, dataset)
+	if err != nil {
+		return Pointer{}, 0, err
+	}
 
 	var curPointer Pointer
-	err := db.incrementField(db_name, group_id, max_ind, &curPointer)
+	err = db.incrementField(db_name, group_id, max_ind, &curPointer)
 	if err != nil {
 		return Pointer{}, 0, err
 	}
@@ -303,16 +311,19 @@ func (db *Mongodb) GetNextRecord(db_name string, group_id string, dataset bool)
 	log_str := "got next pointer " + strconv.Itoa(curPointer.Value) + " for " + db_name + ", groupid: " + group_id
 	logger.Debug(log_str)
 	return db.GetRecordByIDRow(db_name, curPointer.Value, max_ind, dataset)
-
 }
 
+
 func (db *Mongodb) GetLastRecord(db_name string, group_id string, dataset bool) ([]byte, error) {
 
 	if err := db.checkDatabaseOperationPrerequisites(db_name, group_id); err != nil {
 		return nil, err
 	}
 
-	max_ind := db.getMaxIndex(db_name, dataset)
+	max_ind,err := db.getMaxIndex(db_name, dataset)
+	if err !=nil {
+		return nil,err
+	}
 	res, err := db.GetRecordByIDRow(db_name, max_ind, max_ind, dataset)
 
 	db.setCounter(db_name, group_id, max_ind)
diff --git a/broker/src/asapo_broker/database/mongodb_test.go b/broker/src/asapo_broker/database/mongodb_test.go
index 12a4ba197..92e84b252 100644
--- a/broker/src/asapo_broker/database/mongodb_test.go
+++ b/broker/src/asapo_broker/database/mongodb_test.go
@@ -67,7 +67,7 @@ func TestMongoDBConnectOK(t *testing.T) {
 
 func TestMongoDBGetNextErrorWhenNotConnected(t *testing.T) {
 	_, err := db.GetNextRecord("", groupId, false)
-	assert.Equal(t, utils.StatusError, err.(*DBError).Code)
+	assert.Equal(t, utils.StatusServiceUnavailable, err.(*DBError).Code)
 }
 
 func TestMongoDBGetNextErrorWhenWrongDatabasename(t *testing.T) {
@@ -287,7 +287,7 @@ func TestMongoDBGetSizeNoDatabase(t *testing.T) {
 
 func TestMongoDBGetRecordByIDNotConnected(t *testing.T) {
 	_, err := db.GetRecordByID(dbname, "", "2", false)
-	assert.Equal(t, utils.StatusError, err.(*DBError).Code)
+	assert.Equal(t, utils.StatusServiceUnavailable, err.(*DBError).Code)
 }
 
 func TestMongoDBResetCounter(t *testing.T) {
diff --git a/broker/src/asapo_broker/server/process_request.go b/broker/src/asapo_broker/server/process_request.go
index a912998bc..e7a3df859 100644
--- a/broker/src/asapo_broker/server/process_request.go
+++ b/broker/src/asapo_broker/server/process_request.go
@@ -68,8 +68,8 @@ func processRequest(w http.ResponseWriter, r *http.Request, op string, extra_par
 }
 
 func returnError(err error, log_str string) (answer []byte, code int) {
+	code = utils.StatusServiceUnavailable
 	err_db, ok := err.(*database.DBError)
-	code = utils.StatusError
 	if ok {
 		code = err_db.Code
 	}
diff --git a/broker/src/asapo_broker/server/process_request_test.go b/broker/src/asapo_broker/server/process_request_test.go
index 65816108f..c5c6c8cb9 100644
--- a/broker/src/asapo_broker/server/process_request_test.go
+++ b/broker/src/asapo_broker/server/process_request_test.go
@@ -136,7 +136,7 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWithInternalDBError() {
 	ExpectCopyClose(suite.mock_db)
 
 	w := doRequest("/database/" + expectedBeamtimeId + "/" + expectedStream + "/" + expectedGroupID + "/next" + correctTokenSuffix)
-	suite.Equal(http.StatusInternalServerError, w.Code, "internal error")
+	suite.Equal(http.StatusNotFound, w.Code, "internal error")
 }
 
 func (suite *ProcessRequestTestSuite) TestProcessRequestAddsCounter() {
diff --git a/common/go/src/asapo_common/utils/status_codes.go b/common/go/src/asapo_common/utils/status_codes.go
index 58fef4da3..9f6e06162 100644
--- a/common/go/src/asapo_common/utils/status_codes.go
+++ b/common/go/src/asapo_common/utils/status_codes.go
@@ -8,7 +8,8 @@ const (
 )
 const (
 	//error codes
-	StatusError      = http.StatusInternalServerError
-	StatusWrongInput = http.StatusBadRequest
-	StatusNoData     = http.StatusConflict
+	StatusTransactionInterrupted = http.StatusInternalServerError
+	StatusServiceUnavailable	 = http.StatusNotFound
+	StatusWrongInput             = http.StatusBadRequest
+	StatusNoData                 = http.StatusConflict
 )
diff --git a/tests/automatic/consumer/consumer_api_python/consumer_api.py b/tests/automatic/consumer/consumer_api_python/consumer_api.py
index a44fea88b..06c6e1cb5 100644
--- a/tests/automatic/consumer/consumer_api_python/consumer_api.py
+++ b/tests/automatic/consumer/consumer_api_python/consumer_api.py
@@ -33,7 +33,7 @@ def assert_eq(val,expected,name):
 def check_broker_server_error(broker,group_id_new):
     try:
         broker.get_last(group_id_new, meta_only=True)
-    except asapo_consumer.AsapoBrokerServerError as err:
+    except asapo_consumer.AsapoBrokerServersNotFound as err:
         print(err)
         pass
     else:
-- 
GitLab