diff --git a/broker/src/asapo_broker/database/mongodb.go b/broker/src/asapo_broker/database/mongodb.go index d20b44223afe5c76bcb971173ca187db4d6d761a..64dc01e2e9827cf9a15dd46d3887d4897e357530 100644 --- a/broker/src/asapo_broker/database/mongodb.go +++ b/broker/src/asapo_broker/database/mongodb.go @@ -10,6 +10,7 @@ import ( "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" "strconv" + "strings" "sync" "time" ) @@ -132,10 +133,16 @@ func (db *Mongodb) InsertMeta(dbname string, s interface{}) error { return c.Insert(s) } -func (db *Mongodb) getMaxIndex(dbname string) (max_id int, err error) { +func (db *Mongodb) getMaxIndex(dbname string, dataset bool) (max_id int, err error) { c := db.session.DB(dbname).C(data_collection_name) var id Pointer - err = c.Find(nil).Sort("-_id").Select(bson.M{"_id": 1}).One(&id) + var q bson.M + if dataset { + q = bson.M{"$expr": bson.M{"$eq": []interface{}{"$size", bson.M{"$size": "$images"}}}} + } else { + q = nil + } + err = c.Find(q).Sort("-_id").Select(bson.M{"_id": 1}).One(&id) if err != nil { return 0, nil } @@ -176,9 +183,15 @@ func (db *Mongodb) incrementField(dbname string, group_id string, max_ind int, r return err } -func (db *Mongodb) GetRecordByIDRow(dbname string, id int, returnID bool) ([]byte, error) { +func (db *Mongodb) GetRecordByIDRow(dbname string, id int, returnID bool, dataset bool) ([]byte, error) { var res map[string]interface{} - q := bson.M{"_id": id} + var q bson.M + if dataset { + q = bson.M{"$and": []bson.M{bson.M{"_id": id}, bson.M{"$expr": bson.M{"$eq": []interface{}{"$size", bson.M{"$size": "$images"}}}}}} + } else { + q = bson.M{"_id": id} + } + c := db.session.DB(dbname).C(data_collection_name) err := c.Find(q).One(&res) if err != nil { @@ -200,7 +213,7 @@ func (db *Mongodb) GetRecordByIDRow(dbname string, id int, returnID bool) ([]byt return utils.MapToJson(&res) } -func (db *Mongodb) GetRecordByID(dbname string, group_id string, id_str string, returnID bool, reset bool) ([]byte, error) { +func (db *Mongodb) GetRecordByID(dbname string, group_id string, id_str string, returnID bool, reset bool, dataset bool) ([]byte, error) { id, err := strconv.Atoi(id_str) if err != nil { return nil, err @@ -209,7 +222,7 @@ func (db *Mongodb) GetRecordByID(dbname string, group_id string, id_str string, if err := db.checkDatabaseOperationPrerequisites(dbname, group_id); err != nil { return nil, err } - res, err := db.GetRecordByIDRow(dbname, id, returnID) + res, err := db.GetRecordByIDRow(dbname, id, returnID, dataset) if reset { db.setCounter(dbname, group_id, id) @@ -265,7 +278,7 @@ func (db *Mongodb) checkDatabaseOperationPrerequisites(db_name string, group_id } func (db *Mongodb) getCurrentPointer(db_name string, group_id string) (Pointer, error) { - max_ind, err := db.getMaxIndex(db_name) + max_ind, err := db.getMaxIndex(db_name, false) if err != nil { return Pointer{}, err } @@ -278,7 +291,7 @@ func (db *Mongodb) getCurrentPointer(db_name string, group_id string) (Pointer, return curPointer, nil } -func (db *Mongodb) GetNextRecord(db_name string, group_id string) ([]byte, error) { +func (db *Mongodb) GetNextRecord(db_name string, group_id string, dataset bool) ([]byte, error) { if err := db.checkDatabaseOperationPrerequisites(db_name, group_id); err != nil { return nil, err @@ -292,23 +305,23 @@ func (db *Mongodb) GetNextRecord(db_name string, group_id string) ([]byte, error } log_str := "got next pointer " + strconv.Itoa(curPointer.Value) + " for " + db_name + ", groupid: " + group_id logger.Debug(log_str) - return db.GetRecordByIDRow(db_name, curPointer.Value, true) + return db.GetRecordByIDRow(db_name, curPointer.Value, true, dataset) } -func (db *Mongodb) GetLastRecord(db_name string, group_id string) ([]byte, error) { +func (db *Mongodb) GetLastRecord(db_name string, group_id string, dataset bool) ([]byte, error) { if err := db.checkDatabaseOperationPrerequisites(db_name, group_id); err != nil { return nil, err } - max_ind, err := db.getMaxIndex(db_name) + max_ind, err := db.getMaxIndex(db_name, dataset) if err != nil { log_str := "error getting last pointer for " + db_name + ", groupid: " + group_id + ":" + err.Error() logger.Debug(log_str) return nil, err } - res, err := db.GetRecordByIDRow(db_name, max_ind, false) + res, err := db.GetRecordByIDRow(db_name, max_ind, false, dataset) db.setCounter(db_name, group_id, max_ind) @@ -394,15 +407,20 @@ func (db *Mongodb) queryImages(dbname string, query string) ([]byte, error) { } func (db *Mongodb) ProcessRequest(db_name string, group_id string, op string, extra_param string) (answer []byte, err error) { + dataset := false + if strings.HasSuffix(op, "_dataset") { + dataset = true + op = op[:len(op)-8] + } switch op { case "next": - return db.GetNextRecord(db_name, group_id) + return db.GetNextRecord(db_name, group_id, dataset) case "id": - return db.GetRecordByID(db_name, group_id, extra_param, true, false) + return db.GetRecordByID(db_name, group_id, extra_param, true, false, dataset) case "idreset": - return db.GetRecordByID(db_name, group_id, extra_param, true, true) + return db.GetRecordByID(db_name, group_id, extra_param, true, true, dataset) case "last": - return db.GetLastRecord(db_name, group_id) + return db.GetLastRecord(db_name, group_id, dataset) case "resetcounter": return db.ResetCounter(db_name, group_id) case "size": diff --git a/broker/src/asapo_broker/database/mongodb_test.go b/broker/src/asapo_broker/database/mongodb_test.go index 980a53fe8eb3ff66bd2eccbb38337e7dbbf15884..cdca011049bfe5b0e1444e23d989eccd6ad56219 100644 --- a/broker/src/asapo_broker/database/mongodb_test.go +++ b/broker/src/asapo_broker/database/mongodb_test.go @@ -17,6 +17,12 @@ type TestRecord struct { FName string `bson:"fname" json:"fname"` } +type TestDataset struct { + ID int `bson:"_id" json:"_id"` + Size int `bson:"size" json:"size"` + Images []TestRecord `bson:"images" json:"images"` +} + var db Mongodb const dbname = "run1" @@ -60,14 +66,14 @@ func TestMongoDBConnectOK(t *testing.T) { } func TestMongoDBGetNextErrorWhenNotConnected(t *testing.T) { - _, err := db.GetNextRecord("", groupId) + _, err := db.GetNextRecord("", groupId, false) assert.Equal(t, utils.StatusError, err.(*DBError).Code) } func TestMongoDBGetNextErrorWhenWrongDatabasename(t *testing.T) { db.Connect(dbaddress) defer cleanup() - _, err := db.GetNextRecord("", groupId) + _, err := db.GetNextRecord("", groupId, false) assert.Equal(t, utils.StatusWrongInput, err.(*DBError).Code) } @@ -75,7 +81,7 @@ func TestMongoDBGetNextErrorWhenEmptyCollection(t *testing.T) { db.Connect(dbaddress) db.databases = append(db.databases, dbname) defer cleanup() - _, err := db.GetNextRecord(dbname, groupId) + _, err := db.GetNextRecord(dbname, groupId, false) assert.Equal(t, utils.StatusNoData, err.(*DBError).Code) } @@ -83,7 +89,7 @@ func TestMongoDBGetNextErrorWhenRecordNotThereYet(t *testing.T) { db.Connect(dbaddress) defer cleanup() db.InsertRecord(dbname, &rec2) - _, err := db.GetNextRecord(dbname, groupId) + _, err := db.GetNextRecord(dbname, groupId, false) assert.Equal(t, utils.StatusNoData, err.(*DBError).Code) assert.Equal(t, "{\"id\":1}", err.Error()) } @@ -92,7 +98,7 @@ func TestMongoDBGetNextOK(t *testing.T) { db.Connect(dbaddress) defer cleanup() db.InsertRecord(dbname, &rec1) - res, err := db.GetNextRecord(dbname, groupId) + res, err := db.GetNextRecord(dbname, groupId, false) assert.Nil(t, err) assert.Equal(t, string(rec1_expect), string(res)) } @@ -101,8 +107,8 @@ func TestMongoDBGetNextErrorOnNoMoreData(t *testing.T) { db.Connect(dbaddress) defer cleanup() db.InsertRecord(dbname, &rec1) - db.GetNextRecord(dbname, groupId) - _, err := db.GetNextRecord(dbname, groupId) + db.GetNextRecord(dbname, groupId, false) + _, err := db.GetNextRecord(dbname, groupId, false) assert.Equal(t, utils.StatusNoData, err.(*DBError).Code) } @@ -111,8 +117,8 @@ func TestMongoDBGetNextCorrectOrder(t *testing.T) { defer cleanup() db.InsertRecord(dbname, &rec2) db.InsertRecord(dbname, &rec1) - res1, _ := db.GetNextRecord(dbname, groupId) - res2, _ := db.GetNextRecord(dbname, groupId) + res1, _ := db.GetNextRecord(dbname, groupId, false) + res2, _ := db.GetNextRecord(dbname, groupId, false) assert.Equal(t, string(rec1_expect), string(res1)) assert.Equal(t, string(rec2_expect), string(res2)) } @@ -144,7 +150,7 @@ func getRecords(n int) []int { for i := 0; i < n; i++ { go func() { defer wg.Done() - res_bin, _ := db.GetNextRecord(dbname, groupId) + res_bin, _ := db.GetNextRecord(dbname, groupId, false) var res TestRecord json.Unmarshal(res_bin, &res) results[res.ID] = 1 @@ -170,7 +176,7 @@ func TestMongoDBGetRecordByID(t *testing.T) { db.Connect(dbaddress) defer cleanup() db.InsertRecord(dbname, &rec1) - res, err := db.GetRecordByID(dbname, "", "1", true, false) + res, err := db.GetRecordByID(dbname, "", "1", true, false, false) assert.Nil(t, err) assert.Equal(t, string(rec1_expect), string(res)) } @@ -179,7 +185,7 @@ func TestMongoDBGetRecordByIDFails(t *testing.T) { db.Connect(dbaddress) defer cleanup() db.InsertRecord(dbname, &rec1) - _, err := db.GetRecordByID(dbname, "", "2", true, false) + _, err := db.GetRecordByID(dbname, "", "2", true, false, false) assert.Equal(t, utils.StatusNoData, err.(*DBError).Code) assert.Equal(t, "{\"id\":2}", err.Error()) } @@ -287,7 +293,7 @@ func TestMongoDBGetRecordIDWithReset(t *testing.T) { } func TestMongoDBGetRecordByIDNotConnected(t *testing.T) { - _, err := db.GetRecordByID(dbname, "", "2", true, false) + _, err := db.GetRecordByID(dbname, "", "2", true, false, false) assert.Equal(t, utils.StatusError, err.(*DBError).Code) } @@ -416,3 +422,104 @@ func TestMongoDBQueryImagesOK(t *testing.T) { } } + +var rec_dataset1 = TestDataset{1, 3, []TestRecord{rec1, rec2, rec3}} +var rec_dataset2 = TestDataset{2, 2, []TestRecord{rec1, rec2, rec3}} +var rec_dataset3 = TestDataset{3, 3, []TestRecord{rec3, rec2, rec2}} + +func TestMongoDBGetDataset(t *testing.T) { + db.Connect(dbaddress) + defer cleanup() + + db.InsertRecord(dbname, &rec_dataset1) + + res_string, err := db.ProcessRequest(dbname, groupId, "next_dataset", "0") + + assert.Nil(t, err) + + var res TestDataset + json.Unmarshal(res_string, &res) + + assert.Equal(t, rec_dataset1, res) +} + +func TestMongoDBNoDataOnNotCompletedDataset(t *testing.T) { + db.Connect(dbaddress) + defer cleanup() + + db.InsertRecord(dbname, &rec_dataset2) + + res_string, err := db.ProcessRequest(dbname, groupId, "next_dataset", "0") + + assert.Equal(t, utils.StatusNoData, err.(*DBError).Code) + assert.Equal(t, "", string(res_string)) +} + +func TestMongoDBGetRecordLastDataSetSkipsIncompleteSets(t *testing.T) { + db.Connect(dbaddress) + defer cleanup() + + db.InsertRecord(dbname, &rec_dataset1) + db.InsertRecord(dbname, &rec_dataset2) + + res_string, err := db.ProcessRequest(dbname, groupId, "last_dataset", "0") + + assert.Nil(t, err) + + var res TestDataset + json.Unmarshal(res_string, &res) + + assert.Equal(t, rec_dataset1, res) +} + +func TestMongoDBGetRecordLastDataSetOK(t *testing.T) { + db.Connect(dbaddress) + defer cleanup() + + db.InsertRecord(dbname, &rec_dataset1) + db.InsertRecord(dbname, &rec_dataset1) + + res_string, err := db.ProcessRequest(dbname, groupId, "last_dataset", "0") + + assert.Nil(t, err) + + var res TestDataset + json.Unmarshal(res_string, &res) + + assert.Equal(t, rec_dataset3, res) +} + +func TestMongoDBGetDatasetIDWithReset(t *testing.T) { + db.Connect(dbaddress) + defer cleanup() + db.InsertRecord(dbname, &rec_dataset1) + db.InsertRecord(dbname, &rec_dataset3) + + _, err1 := db.ProcessRequest(dbname, groupId, "idreset_dataset", "2") //error while record is not complete, but reset counter to 2 + res2s, err2 := db.ProcessRequest(dbname, groupId, "next_dataset", "0") // so getnext woudl get record number 3 + + assert.NotNil(t, err1) + assert.Nil(t, err2) + + var res2 TestDataset + json.Unmarshal(res2s, &res2) + + assert.Equal(t, rec_dataset3, res2) + +} + +func TestMongoDBGetDatasetID(t *testing.T) { + db.Connect(dbaddress) + defer cleanup() + db.InsertRecord(dbname, &rec_dataset1) + + res_string, err := db.ProcessRequest(dbname, groupId, "id_dataset", "1") + + assert.Nil(t, err) + + var res TestDataset + json.Unmarshal(res_string, &res) + + assert.Equal(t, rec_dataset1, res) + +} diff --git a/broker/src/asapo_broker/server/process_request.go b/broker/src/asapo_broker/server/process_request.go index 3cede6c0014d855064980565dce4ca584fefb5f1..552fb64f8594f775a4a955fb1b22a5056b06d0b3 100644 --- a/broker/src/asapo_broker/server/process_request.go +++ b/broker/src/asapo_broker/server/process_request.go @@ -59,6 +59,10 @@ func processRequest(w http.ResponseWriter, r *http.Request, op string, extra_par op = "idreset" } + if datasetRequested(r) { + op = op + "_dataset" + } + answer, code := processRequestInDb(db_name, group_id, op, extra_param) w.WriteHeader(code) w.Write(answer) diff --git a/broker/src/asapo_broker/server/process_request_test.go b/broker/src/asapo_broker/server/process_request_test.go index 727f2aa3e68d190874ecd206f90863887832d6ef..3b4b4604d84f3fed40874219dd35bb16765b789e 100644 --- a/broker/src/asapo_broker/server/process_request_test.go +++ b/broker/src/asapo_broker/server/process_request_test.go @@ -151,3 +151,11 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestWrongGroupID() { w := doRequest("/database/" + expectedBeamtimeId + "/" + wrongGroupID + "/next" + correctTokenSuffix) suite.Equal(http.StatusBadRequest, w.Code, "wrong group id") } + +func (suite *ProcessRequestTestSuite) TestProcessRequestAddsDataset() { + suite.mock_db.On("ProcessRequest", expectedBeamtimeId, expectedGroupID, "next_dataset", "0").Return([]byte("Hello"), nil) + logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request next_dataset in "+expectedBeamtimeId))) + ExpectCopyClose(suite.mock_db) + + doRequest("/database/" + expectedBeamtimeId + "/" + expectedGroupID + "/next" + correctTokenSuffix + "&dataset=true") +} diff --git a/broker/src/asapo_broker/server/request_common.go b/broker/src/asapo_broker/server/request_common.go index 53ae802ff4f4745da012cb23d425471ca7553f4b..8cda3b9967a66f06a5639f5b99f4d01a498a9be8 100644 --- a/broker/src/asapo_broker/server/request_common.go +++ b/broker/src/asapo_broker/server/request_common.go @@ -13,8 +13,8 @@ func writeAuthAnswer(w http.ResponseWriter, requestName string, db_name string, w.Write([]byte(err)) } -func resetRequested(r *http.Request) bool { - val := r.URL.Query().Get("reset") +func ValueTrue(r *http.Request, key string) bool { + val := r.URL.Query().Get(key) if len(val) == 0 { return false @@ -25,6 +25,15 @@ func resetRequested(r *http.Request) bool { } return false + +} + +func resetRequested(r *http.Request) bool { + return ValueTrue(r, "reset") +} + +func datasetRequested(r *http.Request) bool { + return ValueTrue(r, "dataset") } func testAuth(r *http.Request, beamtime_id string) error {