diff --git a/broker/src/asapo_broker/database/encoding.go b/broker/src/asapo_broker/database/encoding.go new file mode 100644 index 0000000000000000000000000000000000000000..7c61d5428ce59cbdf5d112f8f04e721f426c8585 --- /dev/null +++ b/broker/src/asapo_broker/database/encoding.go @@ -0,0 +1,80 @@ +package database + +import "net/url" + +func shouldEscape(c byte, db bool) bool { + if c == '$' || c == ' ' { + return true + } + if !db { + return false + } + + switch c { + case '\\', '/', '.', '"': + return true + } + return false +} + +const upperhex = "0123456789ABCDEF" + +func escape(s string, db bool) string { + hexCount := 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c, db) { + hexCount++ + } + } + + if hexCount == 0 { + return s + } + + var buf [64]byte + var t []byte + + required := len(s) + 2*hexCount + if required <= len(buf) { + t = buf[:required] + } else { + t = make([]byte, required) + } + + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case shouldEscape(c, db): + t[j] = '%' + t[j+1] = upperhex[c>>4] + t[j+2] = upperhex[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) +} + +func encodeStringForDbName(original string) (result string) { + return escape(original, true) +} + + +func decodeString(original string) (result string) { + result,_ = url.PathUnescape(original) + return result +} + +func encodeStringForColName(original string) (result string) { + return escape(original, false) +} + +func encodeRequest(request *Request) error { + request.DbName = encodeStringForDbName(request.DbName) + request.DbCollectionName = encodeStringForColName(request.DbCollectionName) + request.GroupId = encodeStringForColName(request.GroupId) + return nil +} diff --git a/broker/src/asapo_broker/database/mongodb.go b/broker/src/asapo_broker/database/mongodb.go index 5491086a02aebf133e75639a9498cecc956e9c31..75c885f30ee4290bf9d5980e2f23ebd25b4e9000 100644 --- a/broker/src/asapo_broker/database/mongodb.go +++ b/broker/src/asapo_broker/database/mongodb.go @@ -8,7 +8,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" @@ -67,7 +66,6 @@ const meta_collection_name = "meta" const pointer_collection_name = "current_location" const pointer_field_name = "current_pointer" const no_session_msg = "database client not created" -const wrong_id_type = "wrong id type" const already_connected_msg = "already connected" const finish_stream_keyword = "asapo_finish_stream" @@ -219,7 +217,7 @@ func (db *Mongodb) setCounter(request Request, ind int) (err error) { } func (db *Mongodb) errorWhenCannotIncrementField(request Request, max_ind int) (err error) { - if res, err := db.getRecordFromDb(request, max_ind, max_ind);err == nil { + if res, err := db.getRecordFromDb(request, max_ind, max_ind); err == nil { if err := checkStreamFinished(request, max_ind, max_ind, res); err != nil { return err } @@ -227,7 +225,6 @@ func (db *Mongodb) errorWhenCannotIncrementField(request Request, max_ind int) ( return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind, "")} } - func (db *Mongodb) incrementField(request Request, max_ind int, res interface{}) (err error) { update := bson.M{"$inc": bson.M{pointer_field_name: 1}} opts := options.FindOneAndUpdate().SetUpsert(true).SetReturnDocument(options.After) @@ -244,7 +241,7 @@ func (db *Mongodb) incrementField(request Request, max_ind int, res interface{}) if err2 := c.FindOneAndUpdate(context.TODO(), q, update, opts).Decode(res); err2 == nil { return nil } - return db.errorWhenCannotIncrementField(request,max_ind) + return db.errorWhenCannotIncrementField(request, max_ind) } return &DBError{utils.StatusTransactionInterrupted, err.Error()} } @@ -595,7 +592,6 @@ func checkStreamFinished(request Request, id, id_max int, data map[string]interf return nil } r, ok := ExtractMessageRecord(data) - fmt.Println(r,ok) if !ok || !r.FinishedStream { return nil } @@ -855,9 +851,17 @@ func (db *Mongodb) deleteDocumentsInCollection(request Request, collection strin return err } +func escapeQuery(query string )(res string) { + chars := `\-[]{}()*+?.,^$|#` + for _, char := range chars { + query = strings.ReplaceAll(query,string(char),`\`+string(char)) + } + return query +} + func (db *Mongodb) deleteCollectionsWithPrefix(request Request, prefix string) error { cols, err := db.client.Database(request.DbName).ListCollectionNames(context.TODO(), bson.M{"name": bson.D{ - {"$regex", primitive.Regex{Pattern: "^" + prefix, Options: "i"}}}}) + {"$regex", primitive.Regex{Pattern: "^" + escapeQuery(prefix), Options: "i"}}}}) if err != nil { return err } @@ -881,7 +885,7 @@ func (db *Mongodb) deleteServiceMeta(request Request) error { if err != nil { return err } - return db.deleteDocumentsInCollection(request, pointer_collection_name, "_id", ".*_"+request.DbCollectionName+"$") + return db.deleteDocumentsInCollection(request, pointer_collection_name, "_id", ".*_"+escapeQuery(request.DbCollectionName)+"$") } func (db *Mongodb) deleteStream(request Request) ([]byte, error) { @@ -1000,11 +1004,16 @@ func (db *Mongodb) getStreams(request Request) ([]byte, error) { return json.Marshal(&rec) } + func (db *Mongodb) ProcessRequest(request Request) (answer []byte, err error) { if err := db.checkDatabaseOperationPrerequisites(request); err != nil { return nil, err } + if err := encodeRequest(&request); err != nil { + return nil, err + } + switch request.Op { case "next": return db.getNextRecord(request) diff --git a/broker/src/asapo_broker/database/mongodb_streams.go b/broker/src/asapo_broker/database/mongodb_streams.go index 278ef3c57062196067b1d78c7814b0ecfcfba70e..243df816d1e182c2dd767e5dc0db106a57ee40d8 100644 --- a/broker/src/asapo_broker/database/mongodb_streams.go +++ b/broker/src/asapo_broker/database/mongodb_streams.go @@ -57,7 +57,8 @@ func readStreams(db *Mongodb, db_name string) (StreamsRecord, error) { var rec = StreamsRecord{[]StreamInfo{}} for _, coll := range result { if strings.HasPrefix(coll, data_collection_name_prefix) { - si := StreamInfo{Name: strings.TrimPrefix(coll, data_collection_name_prefix)} + sNameEncoded:= strings.TrimPrefix(coll, data_collection_name_prefix) + si := StreamInfo{Name: decodeString(sNameEncoded)} rec.Streams = append(rec.Streams, si) } } @@ -88,7 +89,7 @@ func findStreamAmongCurrent(currentStreams []StreamInfo, record StreamInfo) (int } func fillInfoFromEarliestRecord(db *Mongodb, db_name string, rec *StreamsRecord, record StreamInfo, i int) error { - res, err := db.getEarliestRawRecord(db_name, record.Name) + res, err := db.getEarliestRawRecord(db_name, encodeStringForColName(record.Name)) if err != nil { return err } @@ -102,7 +103,7 @@ func fillInfoFromEarliestRecord(db *Mongodb, db_name string, rec *StreamsRecord, } func fillInfoFromLastRecord(db *Mongodb, db_name string, rec *StreamsRecord, record StreamInfo, i int) error { - res, err := db.getLastRawRecord(db_name, record.Name) + res, err := db.getLastRawRecord(db_name, encodeStringForColName(record.Name)) if err != nil { return err } diff --git a/broker/src/asapo_broker/database/mongodb_test.go b/broker/src/asapo_broker/database/mongodb_test.go index 72471969075187ea4e5226bec89a167a63f26d07..9b4742ff9c0146509df402fe5295a946806b54c5 100644 --- a/broker/src/asapo_broker/database/mongodb_test.go +++ b/broker/src/asapo_broker/database/mongodb_test.go @@ -36,6 +36,11 @@ const groupId = "bid2a5auidddp1vl71d0" const metaID = 0 const metaID_str = "0" +const badSymbolsDb = `/\."$` +const badSymbolsCol = `$` +const badSymbolsDbEncoded = "%2F%5C%2E%22%24" +const badSymbolsColEncoded ="%24" + var empty_next = map[string]string{"next_stream": ""} var rec1 = TestRecord{1, empty_next, "aaa", 0} @@ -63,6 +68,15 @@ func cleanup() { db.Close() } +func cleanupWithName(name string) { + if db.client == nil { + return + } + db.dropDatabase(name) + db.Close() +} + + // these are the integration tests. They assume mongo db is runnig on 127.0.0.1:27027 // test names should contain MongoDB*** so that go test could find them: // go_integration_test(${TARGET_NAME}-connectdb "./..." "MongoDBConnect") @@ -882,6 +896,10 @@ var testsStreams = []struct { StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss1", Timestamp: 0, LastId: 2, TimestampLast: 1}, StreamInfo{Name: "ss2", Timestamp: 1, LastId: 3, TimestampLast: 2}}}, "two streams", true}, {"ss2", []Stream{{"ss1", []TestRecord{rec1, rec2}}, {"ss2", []TestRecord{rec2, rec3}}}, StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss2", Timestamp: 1, LastId: 3, TimestampLast: 2}}}, "with from", true}, + {"", []Stream{{"ss1$", []TestRecord{rec2, rec1}}}, + StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss1$", Timestamp: 0, LastId: 2, TimestampLast: 1}}}, "one stream encoded", true}, + {"ss2$", []Stream{{"ss1$", []TestRecord{rec1, rec2}}, {"ss2$", []TestRecord{rec2, rec3}}}, StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss2$", Timestamp: 1, LastId: 3, TimestampLast: 2}}}, "with from encoded", true}, + } func TestMongoDBListStreams(t *testing.T) { @@ -889,7 +907,7 @@ func TestMongoDBListStreams(t *testing.T) { db.Connect(dbaddress) for _, stream := range test.streams { for _, rec := range stream.records { - db.insertRecord(dbname, stream.name, &rec) + db.insertRecord(dbname, encodeStringForColName(stream.name), &rec) } } var rec_streams_expect, _ = json.Marshal(test.expectedStreams) @@ -1196,14 +1214,21 @@ var testsDeleteStream = []struct { }{ {"test", "{\"ErrorOnNotExist\":true,\"DeleteMeta\":true}", true,false, "delete stream"}, {"test", "{\"ErrorOnNotExist\":false,\"DeleteMeta\":true}", true, true,"delete stream"}, + {`test$/\ .%&?*#'`, "{\"ErrorOnNotExist\":false,\"DeleteMeta\":true}", true, true,"delete stream"}, + } func TestDeleteStreams(t *testing.T) { + defer cleanup() for _, test := range testsDeleteStream { db.Connect(dbaddress) - db.insertRecord(dbname, test.stream, &rec_finished11) - - _, err := db.ProcessRequest(Request{DbName: dbname, DbCollectionName: test.stream, GroupId: "", Op: "delete_stream", ExtraParam: test.params}) + db.insertRecord(dbname, encodeStringForColName(test.stream), &rec1) + db.ProcessRequest(Request{DbName: dbname, DbCollectionName: test.stream, GroupId: "123", Op: "next"}) + query_str := "{\"Id\":1,\"Op\":\"ackmessage\"}" + request := Request{DbName: dbname, DbCollectionName: test.stream, GroupId: groupId, Op: "ackmessage", ExtraParam: query_str} + _, err := db.ProcessRequest(request) + assert.Nil(t, err, test.message) + _, err = db.ProcessRequest(Request{DbName: dbname, DbCollectionName: test.stream, GroupId: "", Op: "delete_stream", ExtraParam: test.params}) if test.ok { rec, err := streams.getStreams(&db, Request{DbName: dbname, ExtraParam: ""}) acks_exist,_:= db.collectionExist(Request{DbName: dbname, ExtraParam: ""},acks_collection_name_prefix+test.stream) @@ -1223,3 +1248,36 @@ func TestDeleteStreams(t *testing.T) { } } } + + +var testsEncodings = []struct { + dbname string + collection string + group string + dbname_indb string + collection_indb string + group_indb string + message string + ok bool +}{ + {"dbname", "col", "group", "dbname","col","group", "no encoding",true}, + {"dbname"+badSymbolsDb, "col", "group", "dbname"+badSymbolsDbEncoded,"col","group", "symbols in db",true}, + {"dbname", "col"+badSymbolsCol, "group"+badSymbolsCol, "dbname","col"+badSymbolsColEncoded,"group"+badSymbolsColEncoded, "symbols in col",true}, + {"dbname"+badSymbolsDb, "col"+badSymbolsCol, "group"+badSymbolsCol, "dbname"+badSymbolsDbEncoded,"col"+badSymbolsColEncoded,"group"+badSymbolsColEncoded, "symbols in col and db",true}, + +} + +func TestMongoDBEncodingOK(t *testing.T) { + for _, test := range testsEncodings { + db.Connect(dbaddress) + db.insertRecord(test.dbname_indb, test.collection_indb, &rec1) + res, err := db.ProcessRequest(Request{DbName: test.dbname, DbCollectionName: test.collection, GroupId: test.group, Op: "next"}) + if test.ok { + assert.Nil(t, err, test.message) + assert.Equal(t, string(rec1_expect), string(res), test.message) + } else { + assert.Equal(t, utils.StatusWrongInput, err.(*DBError).Code, test.message) + } + cleanupWithName(test.dbname_indb) + } +} \ No newline at end of file diff --git a/broker/src/asapo_broker/server/get_commands_test.go b/broker/src/asapo_broker/server/get_commands_test.go index c472ddb4c3948004c412ea6267da70a582a72084..40c41c2b6bcb9ef70fd2d99567d7ab25caf1cb7d 100644 --- a/broker/src/asapo_broker/server/get_commands_test.go +++ b/broker/src/asapo_broker/server/get_commands_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "net/http" + "net/url" + "strings" "testing" ) @@ -34,23 +36,24 @@ func TestGetCommandsTestSuite(t *testing.T) { var testsGetCommand = []struct { command string + source string stream string groupid string reqString string queryParams string externalParam string }{ - {"last", expectedStream, "", expectedStream + "/0/last","","0"}, - {"id", expectedStream, "", expectedStream + "/0/1","","1"}, - {"meta", "default", "", "default/0/meta/0","","0"}, - {"nacks", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/nacks","","0_0"}, - {"next", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","",""}, - {"next", expectedStream, expectedGroupID, expectedStream + "/" + + {"last", expectedSource,expectedStream, "", expectedStream + "/0/last","","0"}, + {"id", expectedSource,expectedStream, "", expectedStream + "/0/1","","1"}, + {"meta", expectedSource,"default", "", "default/0/meta/0","","0"}, + {"nacks",expectedSource, expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/nacks","","0_0"}, + {"next", expectedSource,expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","",""}, + {"next", expectedSource,expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","&resend_nacks=true&delay_ms=10000&resend_attempts=3","10000_3"}, - {"size", expectedStream, "", expectedStream + "/size","",""}, - {"size", expectedStream, "", expectedStream + "/size","&incomplete=true","true"}, - {"streams", "0", "", "0/streams","","_"}, - {"lastack", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/lastack","",""}, + {"size", expectedSource,expectedStream, "", expectedStream + "/size","",""}, + {"size",expectedSource, expectedStream, "", expectedStream + "/size","&incomplete=true","true"}, + {"streams",expectedSource, "0", "", "0/streams","","_"}, + {"lastack", expectedSource,expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/lastack","",""}, } @@ -58,8 +61,33 @@ func (suite *GetCommandsTestSuite) TestGetCommandsCallsCorrectRoutine() { for _, test := range testsGetCommand { suite.mock_db.On("ProcessRequest", database.Request{DbName: expectedDBName, DbCollectionName: test.stream, GroupId: test.groupid, Op: test.command, ExtraParam: test.externalParam}).Return([]byte("Hello"), nil) logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request "+test.command))) - w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + expectedSource + "/" + test.reqString+correctTokenSuffix+test.queryParams) + w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + test.source + "/" + test.reqString+correctTokenSuffix+test.queryParams) suite.Equal(http.StatusOK, w.Code, test.command+ " OK") suite.Equal("Hello", string(w.Body.Bytes()), test.command+" sends data") } } + +func (suite *GetCommandsTestSuite) TestGetCommandsCorrectlyProcessedEncoding() { + badSymbols:="%$&./\\_$&\"" + for _, test := range testsGetCommand { + newstream := test.stream+badSymbols + newsource := test.source+badSymbols + newgroup :="" + if test.groupid!="" { + newgroup = test.groupid+badSymbols + } + encodedStream:=url.PathEscape(newstream) + encodedSource:=url.PathEscape(newsource) + encodedGroup:=url.PathEscape(newgroup) + test.reqString = strings.Replace(test.reqString,test.groupid,encodedGroup,1) + test.reqString = strings.Replace(test.reqString,test.source,encodedSource,1) + test.reqString = strings.Replace(test.reqString,test.stream,encodedStream,1) + dbname := expectedBeamtimeId + "_" + newsource + suite.mock_db.On("ProcessRequest", database.Request{DbName: dbname, DbCollectionName: newstream, GroupId: newgroup, Op: test.command, ExtraParam: test.externalParam}).Return([]byte("Hello"), nil) + logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request "+test.command))) + w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + encodedSource + "/" + test.reqString+correctTokenSuffix+test.queryParams) + suite.Equal(http.StatusOK, w.Code, test.command+ " OK") + suite.Equal("Hello", string(w.Body.Bytes()), test.command+" sends data") + } +} + diff --git a/broker/src/asapo_broker/server/process_request.go b/broker/src/asapo_broker/server/process_request.go index 7e001f3397feef9dc0384da6d38aa21078caa5e3..23fe151a8865dc7debecc7c770a561fba8eb803b 100644 --- a/broker/src/asapo_broker/server/process_request.go +++ b/broker/src/asapo_broker/server/process_request.go @@ -8,20 +8,33 @@ import ( "asapo_common/version" "github.com/gorilla/mux" "net/http" + "net/url" ) +func readFromMapUnescaped(key string, vars map[string]string) (val string,ok bool) { + if val, ok = vars[key];!ok { + return + } + var err error + if val, err = url.PathUnescape(val);err!=nil { + return "",false + } + return +} + func extractRequestParameters(r *http.Request, needGroupID bool) (string, string, string, string, bool) { vars := mux.Vars(r) db_name, ok1 := vars["beamtime"] - - datasource, ok3 := vars["datasource"] - stream, ok4 := vars["stream"] + datasource, ok3 := readFromMapUnescaped("datasource",vars) + stream, ok4 := readFromMapUnescaped("stream",vars) ok2 := true group_id := "" if needGroupID { - group_id, ok2 = vars["groupid"] + group_id, ok2 = readFromMapUnescaped("groupid",vars) } + + return db_name, datasource, stream, group_id, ok1 && ok2 && ok3 && ok4 } @@ -34,22 +47,6 @@ func IsLetterOrNumbers(s string) bool { return true } - -func checkGroupID(w http.ResponseWriter, needGroupID bool, group_id string, db_name string, op string) bool { - if !needGroupID { - return true - } - if len(group_id) > 0 && len (group_id) < 100 && IsLetterOrNumbers(group_id) { - return true - } - err_str := "wrong groupid name, check length or allowed charecters in " + group_id - log_str := "processing get " + op + " request in " + db_name + " at " + settings.GetDatabaseServer() + ": " + err_str - logger.Error(log_str) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(err_str)) - return false -} - func checkBrokerApiVersion(w http.ResponseWriter, r *http.Request) bool { _, ok := utils.PrecheckApiVersion(w, r, version.GetBrokerApiVersion()) return ok @@ -77,10 +74,6 @@ func processRequest(w http.ResponseWriter, r *http.Request, op string, extra_par return } - if !checkGroupID(w, needGroupID, group_id, db_name, op) { - return - } - request := database.Request{} request.DbName = db_name+"_"+datasource request.Op = op diff --git a/broker/src/asapo_broker/server/process_request_test.go b/broker/src/asapo_broker/server/process_request_test.go index 97248769155270c0bd33a7b194b25c2721169576..5204c1c540f3ca2535f90a78ead3ebb0b72c0da9 100644 --- a/broker/src/asapo_broker/server/process_request_test.go +++ b/broker/src/asapo_broker/server/process_request_test.go @@ -205,12 +205,6 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestAddsCounter() { suite.Equal(1, statistics.GetCounter(), "ProcessRequest increases counter") } -func (suite *ProcessRequestTestSuite) TestProcessRequestWrongGroupID() { - logger.MockLog.On("Error", mock.MatchedBy(containsMatcher("wrong groupid"))) - w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + expectedSource + "/" + expectedStream + "/" + wrongGroupID + "/next" + correctTokenSuffix) - suite.Equal(http.StatusBadRequest, w.Code, "wrong group id") -} - func (suite *ProcessRequestTestSuite) TestProcessRequestAddsDataset() { expectedRequest := database.Request{DbName: expectedDBName, DbCollectionName: expectedStream, GroupId: expectedGroupID, DatasetOp: true, Op: "next"} diff --git a/common/cpp/include/asapo/database/database.h b/common/cpp/include/asapo/database/database.h index 1fe87a7c6f5bc48c2317e816c2be34fdb3a9efcc..a26d1eacd2b4d05935d2428ef3500c0638d24115 100644 --- a/common/cpp/include/asapo/database/database.h +++ b/common/cpp/include/asapo/database/database.h @@ -11,7 +11,6 @@ namespace asapo { constexpr char kDBDataCollectionNamePrefix[] = "data"; constexpr char kDBMetaCollectionName[] = "meta"; - class Database { public: virtual Error Connect(const std::string& address, const std::string& database) = 0; diff --git a/common/cpp/include/asapo/database/db_error.h b/common/cpp/include/asapo/database/db_error.h index b02c60a7129988fadbd6494826a66bd10a94ee6a..6128dcbbd76c50f28a7d8c1c5430bb35345b6734 100644 --- a/common/cpp/include/asapo/database/db_error.h +++ b/common/cpp/include/asapo/database/db_error.h @@ -16,7 +16,8 @@ enum class DBErrorType { kAlreadyConnected, kBadAddress, kMemoryError, - kNoRecord + kNoRecord, + kWrongInput }; using DBError = ServiceError<DBErrorType, ErrorType::kDBError>; @@ -28,6 +29,10 @@ auto const kNoRecord = DBErrorTemplate { "No record", DBErrorType::kNoRecord }; +auto const kWrongInput = DBErrorTemplate { + "Wrong input", DBErrorType::kWrongInput +}; + auto const kNotConnected = DBErrorTemplate { "Not connected", DBErrorType::kNotConnected diff --git a/common/cpp/include/asapo/http_client/http_client.h b/common/cpp/include/asapo/http_client/http_client.h index 3a41ea96b28013c655b5d51ce4abcee6c80977a4..5733c68f52140f375b1a19579e81e819f0fd4e1e 100644 --- a/common/cpp/include/asapo/http_client/http_client.h +++ b/common/cpp/include/asapo/http_client/http_client.h @@ -21,6 +21,7 @@ class HttpClient { virtual Error Post(const std::string& uri, const std::string& cookie, const std::string& input_data, std::string output_file_name, HttpCode* response_code) const noexcept = 0; + virtual std::string UrlEscape(const std::string& uri) const noexcept = 0; virtual ~HttpClient() = default; }; diff --git a/common/cpp/include/asapo/unittests/MockHttpClient.h b/common/cpp/include/asapo/unittests/MockHttpClient.h index 41a5b5d232e2146c7851928ba503792cb959e8f1..1f47dff06f47775f2318e3a424c0acd0c5a10601 100644 --- a/common/cpp/include/asapo/unittests/MockHttpClient.h +++ b/common/cpp/include/asapo/unittests/MockHttpClient.h @@ -37,6 +37,13 @@ class MockHttpClient : public HttpClient { }; + std::string UrlEscape(const std::string& uri) const noexcept override { + return UrlEscape_t(uri); + } + + MOCK_CONST_METHOD1(UrlEscape_t,std::string(const std::string& uri)); + + MOCK_CONST_METHOD3(Get_t, std::string(const std::string& uri, HttpCode* code, ErrorInterface** err)); MOCK_CONST_METHOD5(Post_t, diff --git a/common/cpp/src/database/CMakeLists.txt b/common/cpp/src/database/CMakeLists.txt index 261134f12893116d9577afb3db0d0a1eadac4bb5..776cdc0555470e32a8df438bf6a3beb3d953e68f 100644 --- a/common/cpp/src/database/CMakeLists.txt +++ b/common/cpp/src/database/CMakeLists.txt @@ -1,6 +1,7 @@ set(TARGET_NAME database) set(SOURCE_FILES mongodb_client.cpp + encoding.cpp database.cpp) ################################ @@ -16,3 +17,15 @@ target_include_directories(${TARGET_NAME} PUBLIC ${ASAPO_CXX_COMMON_INCLUDE_DIR} PUBLIC "${MONGOC_STATIC_INCLUDE_DIRS}") target_link_libraries (${TARGET_NAME} PRIVATE "${MONGOC_STATIC_LIBRARIES}") target_compile_definitions (${TARGET_NAME} PRIVATE "${MONGOC_STATIC_DEFINITIONS}") + + +################################ +# Testing +################################ +set(TEST_SOURCE_FILES ../../unittests/database/test_encoding.cpp) + +set(TEST_LIBRARIES "${TARGET_NAME}") + +include_directories(${ASAPO_CXX_COMMON_INCLUDE_DIR}) +gtest(${TARGET_NAME} "${TEST_SOURCE_FILES}" "${TEST_LIBRARIES}") + diff --git a/common/cpp/src/database/encoding.cpp b/common/cpp/src/database/encoding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3112788ea2491dca99bb36ed23c0733ca84151e5 --- /dev/null +++ b/common/cpp/src/database/encoding.cpp @@ -0,0 +1,104 @@ +#include "encoding.h" +#include <regex> + +namespace asapo { + +bool ShouldEscape(char c, bool db) { + if (c == '$' || c == ' ') { + return true; + } + if (!db) { + return false; + } + + switch (c) { + case '/': + case '\\': + case '.': + case '"':return true; + } + return false; +} + +const std::string upperhex = "0123456789ABCDEF"; + +std::string Escape(const std::string &s, bool db) { + auto hexCount = 0; + for (auto i = 0; i < s.size(); i++) { + char c = s[i]; + if (ShouldEscape(c, db)) { + hexCount++; + } + } + + if (hexCount == 0) { + return s; + } + + char t[s.size() + 2 * hexCount + 1]; + t[s.size() + 2 * hexCount] = 0; + auto j = 0; + for (auto i = 0; i < s.size(); i++) { + auto c = s[i]; + if (ShouldEscape(c, db)) { + t[j] = '%'; + t[j + 1] = upperhex[c >> 4]; + t[j + 2] = upperhex[c & 15]; + j += 3; + + } else { + t[j] = c; + j++; + } + } + return t; +} + + +inline int ishex(int x) +{ + return (x >= '0' && x <= '9') || + (x >= 'a' && x <= 'f') || + (x >= 'A' && x <= 'F'); +} + +int decode(const char *s, char *dec) +{ + char *o; + const char *end = s + strlen(s); + int c; + + for (o = dec; s <= end; o++) { + c = *s++; +// if (c == '+') c = ' '; + if (c == '%' && ( !ishex(*s++) || + !ishex(*s++) || + !sscanf(s - 2, "%2x", &c))) + return -1; + if (dec) *o = c; + } + + return o - dec; +} + + +std::string EncodeDbName(const std::string &dbname) { + return Escape(dbname, true); +} + +std::string EncodeColName(const std::string &colname) { + return Escape(colname, false); +} + +std::string DecodeName(const std::string &name) { + char decoded[name.size()]; + auto res = decode(name.c_str(),decoded); + return res>=0?decoded:""; +} + +std::string EscapeQuery(const std::string& query) { + std::regex specialChars { R"([-[\]{}()*+?\\.,\^$|#\s])" }; + return std::regex_replace( query, specialChars, R"(\$&)" ); +} + +} diff --git a/common/cpp/src/database/encoding.h b/common/cpp/src/database/encoding.h new file mode 100644 index 0000000000000000000000000000000000000000..a0b5f47f64eaad4fc9d1e1c853bba4f7693dd460 --- /dev/null +++ b/common/cpp/src/database/encoding.h @@ -0,0 +1,14 @@ +#ifndef ASAPO_ENCODING_H +#define ASAPO_ENCODING_H + +#include <string> + +namespace asapo { + +std::string EncodeDbName(const std::string& dbname); +std::string EncodeColName(const std::string& colname); +std::string DecodeName(const std::string& name); +std::string EscapeQuery(const std::string& name); +} + +#endif //ASAPO_ENCODING_H diff --git a/common/cpp/src/database/mongodb_client.cpp b/common/cpp/src/database/mongodb_client.cpp index 6def7c2e3fccf55ade5f473a671522c2875d30d5..a08ae7da303db69adc886051f773cc5c5bea215c 100644 --- a/common/cpp/src/database/mongodb_client.cpp +++ b/common/cpp/src/database/mongodb_client.cpp @@ -1,7 +1,9 @@ #include "asapo/json_parser/json_parser.h" #include "mongodb_client.h" +#include "encoding.h" #include <chrono> +#include <regex> #include "asapo/database/db_error.h" #include "asapo/common/data_structs.h" @@ -53,18 +55,26 @@ Error MongoDBClient::InitializeClient(const std::string &address) { } -void MongoDBClient::UpdateCurrentCollectionIfNeeded(const std::string &collection_name) const { +Error MongoDBClient::UpdateCurrentCollectionIfNeeded(const std::string &collection_name) const { if (collection_name == current_collection_name_) { - return; + return nullptr; } if (current_collection_ != nullptr) { mongoc_collection_destroy(current_collection_); } + auto encoded_name = EncodeColName(collection_name); + if (encoded_name.size() > maxCollectionNameLength) { + return DBErrorTemplates::kWrongInput.Generate("stream name too long"); + } + current_collection_ = mongoc_client_get_collection(client_, database_name_.c_str(), - collection_name.c_str()); + encoded_name.c_str()); current_collection_name_ = collection_name; mongoc_collection_set_write_concern(current_collection_, write_concern_); + + return nullptr; + } Error MongoDBClient::TryConnectDatabase() { @@ -85,7 +95,11 @@ Error MongoDBClient::Connect(const std::string &address, const std::string &data return err; } - database_name_ = std::move(database_name); + database_name_ = EncodeDbName(database_name); + + if (database_name_.size() > maxDbNameLength) { + return DBErrorTemplates::kWrongInput.Generate("data source name too long"); + } err = TryConnectDatabase(); if (err) { @@ -177,9 +191,11 @@ Error MongoDBClient::Insert(const std::string &collection, const MessageMeta &fi return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; auto document = PrepareBsonDocument(file, &err); if (err) { return err; @@ -200,9 +216,11 @@ Error MongoDBClient::Upsert(const std::string &collection, uint64_t id, const ui return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; auto document = PrepareBsonDocument(data, (ssize_t) size, &err); if (err) { return err; @@ -244,9 +262,11 @@ Error MongoDBClient::InsertAsDatasetMessage(const std::string &collection, const return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; auto document = PrepareBsonDocument(file, &err); if (err) { return err; @@ -275,9 +295,11 @@ Error MongoDBClient::GetRecordFromDb(const std::string &collection, uint64_t id, return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; bson_error_t mongo_err; bson_t* filter; bson_t* opts; @@ -440,9 +462,9 @@ Error StreamInfoFromDbResponse(const std::string &last_record_str, } -Error MongoDBClient::GetStreamInfo(const std::string &collection, StreamInfo* info) const { +Error MongoDBClient::GetEncodedStreamInfo(const std::string &collection_encoded, StreamInfo* info) const { std::string last_record_str, earliest_record_str; - auto err = GetRecordFromDb(collection, 0, GetRecordMode::kLast, &last_record_str); + auto err = GetRecordFromDb(collection_encoded, 0, GetRecordMode::kLast, &last_record_str); if (err) { if (err == DBErrorTemplates::kNoRecord) { // with noRecord error it will return last_id = 0 which can be used to understand that the stream is not started yet @@ -451,7 +473,7 @@ Error MongoDBClient::GetStreamInfo(const std::string &collection, StreamInfo* in } return err; } - err = GetRecordFromDb(collection, 0, GetRecordMode::kEarliest, &earliest_record_str); + err = GetRecordFromDb(collection_encoded, 0, GetRecordMode::kEarliest, &earliest_record_str); if (err) { return err; } @@ -459,6 +481,11 @@ Error MongoDBClient::GetStreamInfo(const std::string &collection, StreamInfo* in return StreamInfoFromDbResponse(last_record_str, earliest_record_str, info); } +Error MongoDBClient::GetStreamInfo(const std::string &collection, StreamInfo* info) const { + std::string collection_encoded = EncodeColName(collection); + return GetEncodedStreamInfo(collection_encoded,info); +} + bool MongoCollectionIsDataStream(const std::string &stream_name) { std::string prefix = std::string(kDBDataCollectionNamePrefix) + "_"; return stream_name.rfind(prefix, 0) == 0; @@ -466,7 +493,7 @@ bool MongoCollectionIsDataStream(const std::string &stream_name) { Error MongoDBClient::UpdateCurrentLastStreamInfo(const std::string &collection_name, StreamInfo* info) const { StreamInfo next_info; - auto err = GetStreamInfo(collection_name, &next_info); + auto err = GetEncodedStreamInfo(collection_name, &next_info); std::string prefix = std::string(kDBDataCollectionNamePrefix) + "_"; if (err) { return err; @@ -518,6 +545,10 @@ Error MongoDBClient::GetLastStream(StreamInfo* info) const { err = DBErrorTemplates::kDBError.Generate(error.message); } + if (err!= nullptr) { + info->name = DecodeName(info->name); + } + bson_destroy(opts); mongoc_database_destroy(database); return err; @@ -527,7 +558,7 @@ Error MongoDBClient::DeleteCollections(const std::string &prefix) const { mongoc_database_t* database; char** strv; bson_error_t error; - std::string querystr = "^" + prefix; + std::string querystr = "^" + EscapeQuery(prefix); bson_t* query = BCON_NEW ("name", BCON_REGEX(querystr.c_str(), "i")); bson_t* opts = BCON_NEW ("nameOnly", BCON_BOOL(true), "filter", BCON_DOCUMENT(query)); database = mongoc_client_get_database(client_, database_name_.c_str()); @@ -556,7 +587,7 @@ Error MongoDBClient::DeleteCollection(const std::string &name) const { mongoc_collection_destroy(collection); if (!r) { if (error.code == 26) { - return DBErrorTemplates::kNoRecord.Generate("collection " + name + " not found in " + database_name_); + return DBErrorTemplates::kNoRecord.Generate("collection " + name + " not found in " + DecodeName(database_name_)); } else { return DBErrorTemplates::kDBError.Generate(std::string(error.message) + ": " + std::to_string(error.code)); } @@ -579,15 +610,16 @@ Error MongoDBClient::DeleteDocumentsInCollection(const std::string &collection_n } Error MongoDBClient::DeleteStream(const std::string &stream) const { - std::string data_col = std::string(kDBDataCollectionNamePrefix) + "_" + stream; - std::string inprocess_col = "inprocess_" + stream; - std::string acks_col = "acks_" + stream; + auto stream_encoded = EncodeColName(stream); + std::string data_col = std::string(kDBDataCollectionNamePrefix) + "_" + stream_encoded; + std::string inprocess_col = "inprocess_" + stream_encoded; + std::string acks_col = "acks_" + stream_encoded; current_collection_name_ = ""; auto err = DeleteCollection(data_col); if (err == nullptr) { DeleteCollections(inprocess_col); DeleteCollections(acks_col); - std::string querystr = ".*_" + stream + "$"; + std::string querystr = ".*_" + EscapeQuery(stream_encoded) + "$"; DeleteDocumentsInCollection("current_location", querystr); } return err; diff --git a/common/cpp/src/database/mongodb_client.h b/common/cpp/src/database/mongodb_client.h index 226c134b4d0e17d3ddab76b9fe4b30c31734d7db..d1be1bf21dc9064246dc817b4d6ccaaad08b0fc7 100644 --- a/common/cpp/src/database/mongodb_client.h +++ b/common/cpp/src/database/mongodb_client.h @@ -39,6 +39,9 @@ enum class GetRecordMode { kEarliest }; +const size_t maxDbNameLength = 63; +const size_t maxCollectionNameLength = 100; + class MongoDBClient final : public Database { public: MongoDBClient(); @@ -63,7 +66,7 @@ class MongoDBClient final : public Database { void CleanUp(); std::string DBAddress(const std::string& address) const; Error InitializeClient(const std::string& address); - void UpdateCurrentCollectionIfNeeded(const std::string& collection_name) const ; + Error UpdateCurrentCollectionIfNeeded(const std::string& collection_name) const ; Error Ping(); Error TryConnectDatabase(); Error InsertBsonDocument(const bson_p& document, bool ignore_duplicates) const; @@ -72,6 +75,7 @@ class MongoDBClient final : public Database { Error GetRecordFromDb(const std::string& collection, uint64_t id, GetRecordMode mode, std::string* res) const; Error UpdateLastStreamInfo(const char *str, StreamInfo* info) const; Error UpdateCurrentLastStreamInfo(const std::string& collection_name, StreamInfo* info) const; + Error GetEncodedStreamInfo(const std::string& collection, StreamInfo* info) const; Error DeleteCollection(const std::string& name) const; Error DeleteCollections(const std::string &prefix) const; Error DeleteDocumentsInCollection(const std::string &collection_name,const std::string &querystr) const; diff --git a/common/cpp/src/http_client/curl_http_client.cpp b/common/cpp/src/http_client/curl_http_client.cpp index 0bdefda2189197185f2c5e7797dc1b0f1a491ffa..aff8545ceabc856c144ea21e4ceef31f905c8f54 100644 --- a/common/cpp/src/http_client/curl_http_client.cpp +++ b/common/cpp/src/http_client/curl_http_client.cpp @@ -218,5 +218,18 @@ CurlHttpClient::~CurlHttpClient() { curl_easy_cleanup(curl_); } } +std::string CurlHttpClient::UrlEscape(const std::string &uri) const noexcept { + if (!curl_) { + return ""; + } + char *output = curl_easy_escape(curl_, uri.c_str(), uri.size()); + if (output) { + auto res = std::string(output); + curl_free(output); + return res; + } else { + return ""; + } +} } diff --git a/common/cpp/src/http_client/curl_http_client.h b/common/cpp/src/http_client/curl_http_client.h index cfc2e7626974422a72d87b371660cd7052cc34ea..50ffc1cef90df7c511ecc37f1352bab3287baead 100644 --- a/common/cpp/src/http_client/curl_http_client.h +++ b/common/cpp/src/http_client/curl_http_client.h @@ -34,6 +34,8 @@ class CurlHttpClient final : public HttpClient { std::string Get(const std::string& uri, HttpCode* response_code, Error* err) const noexcept override; std::string Post(const std::string& uri, const std::string& cookie, const std::string& data, HttpCode* response_code, Error* err) const noexcept override; + std::string UrlEscape(const std::string& uri) const noexcept override; + Error Post(const std::string& uri, const std::string& cookie, const std::string& input_data, MessageData* output_data, uint64_t output_data_size, HttpCode* response_code) const noexcept override; diff --git a/common/cpp/unittests/database/test_encoding.cpp b/common/cpp/unittests/database/test_encoding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8cd7ad3f1967f174f95f9731a9d83381a4684f8 --- /dev/null +++ b/common/cpp/unittests/database/test_encoding.cpp @@ -0,0 +1,50 @@ +#include "../../src/database/encoding.h" +#include "asapo/preprocessor/definitions.h" +#include <gmock/gmock.h> +#include "gtest/gtest.h" +#include <chrono> + +using ::testing::AtLeast; +using ::testing::Eq; +using ::testing::Ne; +using ::testing::Test; +using ::testing::_; +using ::testing::Mock; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SetArgPointee; + + +namespace { + +uint64_t big_uint = 18446744073709551615ull; + +TEST(EncodingTests, EncodeDbName) { + + std::string dbname = R"(db_/\."$)"; + std::string dbname_encoded = "db_%2F%5C%2E%22%24"; + + auto encoded = asapo::EncodeDbName(dbname); + auto decoded = asapo::DecodeName(encoded); + + ASSERT_THAT(encoded, Eq(dbname_encoded)); + ASSERT_THAT(decoded, Eq(dbname)); +} + +TEST(EncodingTests, EncodeColName) { + + std::string colname = R"(col_/\."$)"; + std::string colname_encoded = R"(col_/\."%24)"; + + auto encoded = asapo::EncodeColName(colname); + + auto decoded = asapo::DecodeName(encoded); + + + ASSERT_THAT(encoded, Eq(colname_encoded)); + ASSERT_THAT(decoded, Eq(colname)); +} + + + +} diff --git a/common/go/src/asapo_common/utils/routes.go b/common/go/src/asapo_common/utils/routes.go index 598c443d92ce204243904ccd8bc9201b7ad94014..83b78f4180b6ac6e31137e3e7fcbb9c0f6f1845b 100644 --- a/common/go/src/asapo_common/utils/routes.go +++ b/common/go/src/asapo_common/utils/routes.go @@ -16,7 +16,7 @@ type Route struct { } func NewRouter(listRoutes Routes) *mux.Router { - router := mux.NewRouter() + router := mux.NewRouter().UseEncodedPath() for _, route := range listRoutes { router. Methods(route.Method). diff --git a/consumer/api/cpp/src/consumer_impl.cpp b/consumer/api/cpp/src/consumer_impl.cpp index f6e8d95ec11d865111f6de48e7b23534d13b777f..91898b3892bd390410648b40a22b614839519d69 100644 --- a/consumer/api/cpp/src/consumer_impl.cpp +++ b/consumer/api/cpp/src/consumer_impl.cpp @@ -131,7 +131,7 @@ ConsumerImpl::ConsumerImpl(std::string server_uri, if (source_credentials_.data_source.empty()) { source_credentials_.data_source = SourceCredentials::kDefaultDataSource; } - + data_source_encoded_ = httpclient__->UrlEscape(source_credentials_.data_source); } void ConsumerImpl::SetTimeout(uint64_t timeout_ms) { @@ -280,10 +280,8 @@ Error ConsumerImpl::GetRecordFromServer(std::string* response, std::string group interrupt_flag_ = false; std::string request_suffix = OpToUriCmd(op); std::string request_group = OpToUriCmd(op); - std::string - request_api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source - + "/" + std::move(stream); + + std::string request_api = UriPrefix(std::move(stream),"",""); uint64_t elapsed_ms = 0; Error no_data_error; while (true) { @@ -294,7 +292,7 @@ Error ConsumerImpl::GetRecordFromServer(std::string* response, std::string group auto start = system_clock::now(); auto err = DiscoverService(kBrokerServiceName, ¤t_broker_uri_); if (err == nullptr) { - auto ri = PrepareRequestInfo(request_api + "/" + group_id + "/" + request_suffix, dataset, min_size); + auto ri = PrepareRequestInfo(request_api + "/" + httpclient__->UrlEscape(group_id) + "/" + request_suffix, dataset, min_size); if (request_suffix == "next" && resend_) { ri.extra_params = ri.extra_params + "&resend_nacks=true" + "&delay_ms=" + std::to_string(delay_ms_) + "&resend_attempts=" + std::to_string(resend_attempts_); @@ -563,9 +561,8 @@ Error ConsumerImpl::ResetLastReadMarker(std::string group_id, std::string stream Error ConsumerImpl::SetLastReadMarker(std::string group_id, uint64_t value, std::string stream) { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + "/" - + std::move(stream) + "/" + std::move(group_id) + "/resetcounter"; + ri.api = UriPrefix(std::move(stream),std::move(group_id),"resetcounter"); + ri.extra_params = "&value=" + std::to_string(value); ri.post = true; @@ -594,11 +591,9 @@ Error ConsumerImpl::GetRecordFromServerById(uint64_t id, std::string* response, } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move( - group_id) + "/" + std::to_string(id); + ri.api = UriPrefix(std::move(stream),std::move(group_id),std::to_string(id)); + + if (dataset) { ri.extra_params += "&dataset=true"; ri.extra_params += "&minsize=" + std::to_string(min_size); @@ -611,9 +606,7 @@ Error ConsumerImpl::GetRecordFromServerById(uint64_t id, std::string* response, std::string ConsumerImpl::GetBeamtimeMeta(Error* err) { RequestInfo ri; - ri.api = - "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + "/default/0/meta/0"; + ri.api = UriPrefix("default","0","meta/0"); return BrokerRequestWithTimeout(ri, err); } @@ -635,9 +628,8 @@ MessageMetas ConsumerImpl::QueryMessages(std::string query, std::string stream, } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - "/" + std::move(stream) + "/0/querymessages"; + ri.api = UriPrefix(std::move(stream),"0","querymessages"); + ri.post = true; ri.body = std::move(query); @@ -731,8 +723,7 @@ StreamInfos ConsumerImpl::GetStreamList(std::string from, StreamFilter filter, E RequestInfo ConsumerImpl::GetStreamListRequest(const std::string &from, const StreamFilter &filter) const { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + "/0/streams"; + ri.api = UriPrefix("0","","streams"); ri.post = false; if (!from.empty()) { ri.extra_params = "&from=" + from; @@ -800,10 +791,7 @@ Error ConsumerImpl::Acknowledge(std::string group_id, uint64_t id, std::string s return ConsumerErrorTemplates::kWrongInput.Generate("empty stream"); } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/" + std::to_string(id); + ri.api = UriPrefix(std::move(stream),std::move(group_id),std::to_string(id)); ri.post = true; ri.body = "{\"Op\":\"ackmessage\"}"; @@ -822,10 +810,7 @@ IdList ConsumerImpl::GetUnacknowledgedMessages(std::string group_id, return {}; } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/nacks"; + ri.api = UriPrefix(std::move(stream),std::move(group_id),"nacks"); ri.extra_params = "&from=" + std::to_string(from_id) + "&to=" + std::to_string(to_id); auto json_string = BrokerRequestWithTimeout(ri, error); @@ -848,10 +833,7 @@ uint64_t ConsumerImpl::GetLastAcknowledgedMessage(std::string group_id, std::str return 0; } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/lastack"; + ri.api = UriPrefix(std::move(stream),std::move(group_id),"lastack"); auto json_string = BrokerRequestWithTimeout(ri, error); if (*error) { @@ -884,10 +866,7 @@ Error ConsumerImpl::NegativeAcknowledge(std::string group_id, return ConsumerErrorTemplates::kWrongInput.Generate("empty stream"); } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/" + std::to_string(id); + ri.api = UriPrefix(std::move(stream),std::move(group_id),std::to_string(id)); ri.post = true; ri.body = R"({"Op":"negackmessage","Params":{"DelayMs":)" + std::to_string(delay_ms) + "}}"; @@ -929,9 +908,7 @@ uint64_t ConsumerImpl::ParseGetCurrentCountResponce(Error* err, const std::strin RequestInfo ConsumerImpl::GetSizeRequestForSingleMessagesStream(std::string &stream) const { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + "/size"; + ri.api = UriPrefix(std::move(stream),"","size"); return ri; } @@ -971,10 +948,7 @@ Error ConsumerImpl::GetVersionInfo(std::string* client_info, std::string* server RequestInfo ConsumerImpl::GetDeleteStreamRequest(std::string stream, DeleteStreamOptions options) const { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/delete"; + ri.api = UriPrefix(std::move(stream),"","delete"); ri.post = true; ri.body = options.Json(); return ri; @@ -987,4 +961,20 @@ Error ConsumerImpl::DeleteStream(std::string stream, DeleteStreamOptions options return err; } +std::string ConsumerImpl::UriPrefix( std::string stream, std::string group, std::string suffix) const { + auto stream_encoded = httpclient__->UrlEscape(std::move(stream)); + auto group_encoded = group.size()>0?httpclient__->UrlEscape(std::move(group)):""; + auto uri = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" + + data_source_encoded_ + "/" + stream_encoded; + if (group_encoded.size()>0) { + uri = uri + "/" + group_encoded; + } + if (suffix.size()>0) { + uri = uri + "/" + suffix; + } + + return uri; + +} + } \ No newline at end of file diff --git a/consumer/api/cpp/src/consumer_impl.h b/consumer/api/cpp/src/consumer_impl.h index 7f7411c99e8984ea1b1c58ae7b9d747e79f3177d..be4acf6e79eacca4dab1935451ad5d45a8047235 100644 --- a/consumer/api/cpp/src/consumer_impl.h +++ b/consumer/api/cpp/src/consumer_impl.h @@ -148,6 +148,7 @@ class ConsumerImpl final : public asapo::Consumer { uint64_t GetCurrentCount(std::string stream, const RequestInfo& ri, Error* err); RequestInfo GetStreamListRequest(const std::string &from, const StreamFilter &filter) const; Error GetServerVersionInfo(std::string* server_info, bool* supported) ; + std::string UriPrefix( std::string stream, std::string group, std::string suffix) const; std::string endpoint_; std::string current_broker_uri_; @@ -155,6 +156,7 @@ class ConsumerImpl final : public asapo::Consumer { std::string source_path_; bool has_filesystem_; SourceCredentials source_credentials_; + std::string data_source_encoded_; uint64_t timeout_ms_ = 0; bool should_try_rdma_first_ = true; NetworkConnectionType current_connection_type_ = NetworkConnectionType::kUndefined; diff --git a/consumer/api/cpp/unittests/test_consumer_impl.cpp b/consumer/api/cpp/unittests/test_consumer_impl.cpp index 93511dee7837f231368c258b2c0039b680cb5227..ebb3eaa174b5f3ce13ec02013638c9f8f198fb5d 100644 --- a/consumer/api/cpp/unittests/test_consumer_impl.cpp +++ b/consumer/api/cpp/unittests/test_consumer_impl.cpp @@ -79,9 +79,13 @@ class ConsumerImplTests : public Test { std::string expected_path = "/tmp/beamline/beamtime"; std::string expected_filename = "filename"; std::string expected_full_path = std::string("/tmp/beamline/beamtime") + asapo::kPathSeparator + expected_filename; - std::string expected_group_id = "groupid"; - std::string expected_data_source = "source"; - std::string expected_stream = "stream"; + std::string expected_group_id = "groupid$"; + std::string expected_data_source = "source/$.?"; + std::string expected_stream = "str $ eam$"; + std::string expected_group_id_encoded = "groupid%24"; + std::string expected_data_source_encoded = "source%2F%24.%3F"; + std::string expected_stream_encoded = "str%20%24%20eam%24"; + std::string expected_metadata = "{\"meta\":1}"; std::string expected_query_string = "bla"; std::string expected_folder_token = "folder_token"; @@ -116,6 +120,13 @@ class ConsumerImplTests : public Test { fts_consumer->io__ = std::unique_ptr<IO>{&mock_io}; fts_consumer->httpclient__ = std::unique_ptr<asapo::HttpClient>{&mock_http_client}; fts_consumer->net_client__ = std::unique_ptr<asapo::NetClient>{&mock_netclient}; + ON_CALL(mock_http_client, UrlEscape_t(expected_stream)).WillByDefault(Return(expected_stream_encoded)); + ON_CALL(mock_http_client, UrlEscape_t(expected_group_id)).WillByDefault(Return(expected_group_id_encoded)); + ON_CALL(mock_http_client, UrlEscape_t(expected_data_source)).WillByDefault(Return(expected_data_source_encoded)); + ON_CALL(mock_http_client, UrlEscape_t("0")).WillByDefault(Return("0")); + ON_CALL(mock_http_client, UrlEscape_t("")).WillByDefault(Return("")); + ON_CALL(mock_http_client, UrlEscape_t("default")).WillByDefault(Return("default")); + ON_CALL(mock_http_client, UrlEscape_t("stream")).WillByDefault(Return("stream")); } void TearDown() override { @@ -214,7 +225,7 @@ TEST_F(ConsumerImplTests, DefaultStreamIsDetector) { MockGetBrokerUri(); EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/detector/stream/" + expected_group_id + Get_t(expected_broker_api + "/beamtime/beamtime_id/detector/stream/" + expected_group_id_encoded + "/next?token=" + expected_token, _, @@ -223,14 +234,14 @@ TEST_F(ConsumerImplTests, DefaultStreamIsDetector) { SetArgPointee<2>(nullptr), Return(""))); - consumer->GetNext(expected_group_id, &info, nullptr, expected_stream); + consumer->GetNext(expected_group_id, &info, nullptr, "stream"); } TEST_F(ConsumerImplTests, GetNextUsesCorrectUriWithStream) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + expected_group_id + "/next?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/" + expected_group_id_encoded + "/next?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -243,7 +254,7 @@ TEST_F(ConsumerImplTests, GetLastUsesCorrectUri) { MockGetBrokerUri(); EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/"+ expected_stream+"/0/last?token=" + Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/"+ expected_stream_encoded+"/0/last?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -426,7 +437,7 @@ TEST_F(ConsumerImplTests, GetMessageReturnsNoDataAfterTimeoutEvenIfOtherErrorOcc Return("{\"op\":\"get_record_by_id\",\"id\":" + std::to_string(expected_dataset_id) + ",\"id_max\":2,\"next_stream\":\"""\"}"))); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).Times(AtLeast(1)).WillRepeatedly(DoAll( SetArgPointee<1>(HttpCode::NotFound), @@ -434,7 +445,7 @@ TEST_F(ConsumerImplTests, GetMessageReturnsNoDataAfterTimeoutEvenIfOtherErrorOcc Return(""))); consumer->SetTimeout(300); - auto err = consumer->GetNext(expected_group_id, &info, nullptr, expected_stream); + auto err = consumer->GetNext(expected_group_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kNoData)); } @@ -627,13 +638,13 @@ TEST_F(ConsumerImplTests, ResetCounterByDefaultUsesCorrectUri) { consumer->SetTimeout(100); EXPECT_CALL(mock_http_client, - Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/" + - expected_group_id + + Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/" + + expected_group_id_encoded + "/resetcounter?token=" + expected_token + "&value=0", _, _, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), SetArgPointee<4>(nullptr), Return(""))); - auto err = consumer->ResetLastReadMarker(expected_group_id, expected_stream); + auto err = consumer->ResetLastReadMarker(expected_group_id, "stream"); ASSERT_THAT(err, Eq(nullptr)); } @@ -641,9 +652,9 @@ TEST_F(ConsumerImplTests, ResetCounterUsesCorrectUri) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/resetcounter?token=" + expected_token + "&value=10", _, _, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), SetArgPointee<4>(nullptr), @@ -656,8 +667,8 @@ TEST_F(ConsumerImplTests, GetCurrentSizeUsesCorrectUri) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/size?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/size?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), @@ -672,8 +683,8 @@ TEST_F(ConsumerImplTests, GetCurrentSizeErrorOnWrongResponce) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + - "/"+expected_stream+"/size?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/"+expected_stream_encoded+"/size?token=" + expected_token, _, _)).WillRepeatedly(DoAll( SetArgPointee<1>(HttpCode::Unauthorized), SetArgPointee<2>(nullptr), @@ -688,14 +699,14 @@ TEST_F(ConsumerImplTests, GetNDataErrorOnWrongParse) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/size?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return("{\"siz\":10}"))); asapo::Error err; - auto size = consumer->GetCurrentSize(expected_stream,&err); + auto size = consumer->GetCurrentSize("stream",&err); ASSERT_THAT(err, Ne(nullptr)); ASSERT_THAT(size, Eq(0)); } @@ -706,7 +717,7 @@ TEST_F(ConsumerImplTests, GetByIdUsesCorrectUri) { auto to_send = CreateFI(); auto json = to_send.Json(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0/" + std::to_string( expected_dataset_id) + "?token=" + expected_token, _, @@ -715,7 +726,7 @@ TEST_F(ConsumerImplTests, GetByIdUsesCorrectUri) { SetArgPointee<2>(nullptr), Return(json))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(info.name, Eq(to_send.name)); @@ -725,14 +736,14 @@ TEST_F(ConsumerImplTests, GetByIdTimeouts) { MockGetBrokerUri(); consumer->SetTimeout(10); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::Conflict), SetArgPointee<2>(nullptr), Return(""))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kNoData)); } @@ -741,14 +752,14 @@ TEST_F(ConsumerImplTests, GetByIdReturnsEndOfStream) { MockGetBrokerUri(); consumer->SetTimeout(10); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::Conflict), SetArgPointee<2>(nullptr), Return("{\"op\":\"get_record_by_id\",\"id\":1,\"id_max\":1,\"next_stream\":\"""\"}"))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kEndOfStream)); } @@ -757,14 +768,14 @@ TEST_F(ConsumerImplTests, GetByIdReturnsEndOfStreamWhenIdTooLarge) { MockGetBrokerUri(); consumer->SetTimeout(10); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::Conflict), SetArgPointee<2>(nullptr), Return("{\"op\":\"get_record_by_id\",\"id\":100,\"id_max\":1,\"next_stream\":\"""\"}"))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kEndOfStream)); } @@ -773,7 +784,7 @@ TEST_F(ConsumerImplTests, GetMetaDataOK) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/default/0/meta/0?token=" + expected_token, _, _)).WillOnce(DoAll( @@ -879,7 +890,7 @@ TEST_F(ConsumerImplTests, QueryMessagesReturnRecords) { auto responce_string = "[" + json1 + "," + json2 + "]"; EXPECT_CALL(mock_http_client, - Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0" + + Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0" + "/querymessages?token=" + expected_token, _, expected_query_string, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), SetArgPointee<4>(nullptr), @@ -887,7 +898,7 @@ TEST_F(ConsumerImplTests, QueryMessagesReturnRecords) { consumer->SetTimeout(100); asapo::Error err; - auto messages = consumer->QueryMessages(expected_query_string, expected_stream, &err); + auto messages = consumer->QueryMessages(expected_query_string, "stream", &err); ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(messages.size(), Eq(2)); @@ -899,15 +910,15 @@ TEST_F(ConsumerImplTests, QueryMessagesReturnRecords) { TEST_F(ConsumerImplTests, GetNextDatasetUsesCorrectUri) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/" + - expected_group_id + "/next?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/" + + expected_group_id_encoded + "/next?token=" + expected_token + "&dataset=true&minsize=0", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(""))); asapo::Error err; - consumer->GetNextDataset(expected_group_id, 0, expected_stream, &err); + consumer->GetNextDataset(expected_group_id, 0, "stream", &err); } TEST_F(ConsumerImplTests, GetNextErrorOnEmptyStream) { @@ -1034,8 +1045,8 @@ TEST_F(ConsumerImplTests, GetDataSetReturnsParseError) { TEST_F(ConsumerImplTests, GetLastDatasetUsesCorrectUri) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/0/last?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/0/last?token=" + expected_token + "&dataset=true&minsize=1", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1048,7 +1059,7 @@ TEST_F(ConsumerImplTests, GetLastDatasetUsesCorrectUri) { TEST_F(ConsumerImplTests, GetDatasetByIdUsesCorrectUri) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token + "&dataset=true" + "&minsize=0", _, _)).WillOnce(DoAll( @@ -1056,13 +1067,13 @@ TEST_F(ConsumerImplTests, GetDatasetByIdUsesCorrectUri) { SetArgPointee<2>(nullptr), Return(""))); asapo::Error err; - consumer->GetDatasetById(expected_dataset_id, 0, expected_stream, &err); + consumer->GetDatasetById(expected_dataset_id, 0, "stream", &err); } TEST_F(ConsumerImplTests, DeleteStreamUsesCorrectUri) { MockGetBrokerUri(); std::string expected_delete_stream_query_string = "{\"ErrorOnNotExist\":true,\"DeleteMeta\":true}"; - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/"+expected_stream+"/delete" + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/"+expected_stream_encoded+"/delete" + "?token=" + expected_token, _, expected_delete_stream_query_string, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), @@ -1084,7 +1095,7 @@ TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUri) { std::string(R"({"streams":[{"lastId":123,"name":"test","timestampCreated":1000000,"timestampLast":1000,"finished":false,"nextStream":""},)")+ R"({"lastId":124,"name":"test1","timestampCreated":2000000,"timestampLast":2000,"finished":true,"nextStream":"next"}]})"; EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/0/streams" + Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/0/streams" + "?token=" + expected_token + "&from=stream_from&filter=all", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1103,7 +1114,7 @@ TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUri) { TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUriWithoutFrom) { MockGetBrokerUri(); EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/0/streams" + Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/0/streams" + "?token=" + expected_token+"&filter=finished", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1245,9 +1256,9 @@ TEST_F(ConsumerImplTests, GetMessageTriesToGetTokenAgainIfTransferFailed) { TEST_F(ConsumerImplTests, AcknowledgeUsesCorrectUri) { MockGetBrokerUri(); auto expected_acknowledge_command = "{\"Op\":\"ackmessage\"}"; - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, expected_acknowledge_command, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), @@ -1261,9 +1272,9 @@ TEST_F(ConsumerImplTests, AcknowledgeUsesCorrectUri) { void ConsumerImplTests::ExpectIdList(bool error) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + "/nacks?token=" + expected_token + "&from=1&to=0", _, _)).WillOnce(DoAll( + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/nacks?token=" + expected_token + "&from=1&to=0", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(error ? "" : "{\"unacknowledged\":[1,2,3]}"))); @@ -1279,9 +1290,9 @@ TEST_F(ConsumerImplTests, GetUnAcknowledgedListReturnsIds) { } void ConsumerImplTests::ExpectLastAckId(bool empty_response) { - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + "/lastack?token=" + expected_token, _, _)).WillOnce(DoAll( + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/lastack?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(empty_response ? "{\"lastAckId\":0}" : "{\"lastAckId\":1}"))); @@ -1317,8 +1328,8 @@ TEST_F(ConsumerImplTests, GetByIdErrorsForId0) { TEST_F(ConsumerImplTests, ResendNacks) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/" - + expected_group_id + "/next?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/" + + expected_group_id_encoded + "/next?token=" + expected_token + "&resend_nacks=true&delay_ms=10000&resend_attempts=3", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1326,15 +1337,15 @@ TEST_F(ConsumerImplTests, ResendNacks) { Return(""))); consumer->SetResendNacs(true, 10000, 3); - consumer->GetNext(expected_group_id, &info, nullptr, expected_stream); + consumer->GetNext(expected_group_id, &info, nullptr, "stream"); } TEST_F(ConsumerImplTests, NegativeAcknowledgeUsesCorrectUri) { MockGetBrokerUri(); auto expected_neg_acknowledge_command = R"({"Op":"negackmessage","Params":{"DelayMs":10000}})"; - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, expected_neg_acknowledge_command, _, _)).WillOnce( DoAll( @@ -1377,8 +1388,8 @@ TEST_F(ConsumerImplTests, GetCurrentDataSetCounteUsesCorrectUri) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/size?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/size?token=" + expected_token+"&incomplete=true", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), diff --git a/examples/pipeline/in_to_out_python/check_linux.sh b/examples/pipeline/in_to_out_python/check_linux.sh index 4a1a7438585878bf130886ffa9405174ef151ab6..04c4fef60afd82c449055f992e913efaa46f60c5 100644 --- a/examples/pipeline/in_to_out_python/check_linux.sh +++ b/examples/pipeline/in_to_out_python/check_linux.sh @@ -2,15 +2,15 @@ source_path=. beamtime_id=asapo_test -data_source_in=detector -data_source_out=data_source +data_source_in=detector/123 +data_source_out=data_source/12.4 timeout=15 timeout_producer=25 nthreads=4 -indatabase_name=${beamtime_id}_${data_source_in} -outdatabase_name=${beamtime_id}_${data_source_out} +indatabase_name=${beamtime_id}_detector%2F123 +outdatabase_name=${beamtime_id}_data_source%2F12%2E4 #asapo_test read token token=$ASAPO_TEST_RW_TOKEN @@ -33,7 +33,7 @@ Cleanup() { echo "db.dropDatabase()" | mongo ${outdatabase_name} rm -rf processed rm -rf ${receiver_root_folder} - rm -rf out +# rm -rf out } diff --git a/examples/pipeline/in_to_out_python/in_to_out.py b/examples/pipeline/in_to_out_python/in_to_out.py index 0e58c1b0a0daf249ca960344e3088b177de570dd..93e7328e040c8f6eceb1a37dcc555b8c6f6702c7 100644 --- a/examples/pipeline/in_to_out_python/in_to_out.py +++ b/examples/pipeline/in_to_out_python/in_to_out.py @@ -7,6 +7,8 @@ import threading lock = threading.Lock() +print (asapo_consumer.__version__) +print (asapo_producer.__version__) n_send = 0 diff --git a/receiver/CMakeLists.txt b/receiver/CMakeLists.txt index 0bc661d5f7eea49a3926a5bc71641d0adb24a4b0..4a18b189606b27a371224d9a12dead194b711ed8 100644 --- a/receiver/CMakeLists.txt +++ b/receiver/CMakeLists.txt @@ -21,7 +21,7 @@ set(RECEIVER_CORE_FILES src/request_handler/request_handler_db_last_stream.cpp src/request_handler/request_handler_receive_metadata.cpp src/request_handler/request_handler_db_check_request.cpp - src/request_handler/request_handler_delete_stream.cpp + src/request_handler/request_handler_db_delete_stream.cpp src/request_handler/request_factory.cpp src/request_handler/request_handler_db.cpp src/file_processors/write_file_processor.cpp diff --git a/receiver/src/request_handler/request_factory.h b/receiver/src/request_handler/request_factory.h index 374c586ab0e7609effa52cfe12017b471561734a..ee371d5aca70e23200079f9a6a3ce5072622cb34 100644 --- a/receiver/src/request_handler/request_factory.h +++ b/receiver/src/request_handler/request_factory.h @@ -6,7 +6,7 @@ #include "../file_processors/receive_file_processor.h" #include "request_handler_db_stream_info.h" #include "request_handler_db_last_stream.h" -#include "request_handler_delete_stream.h" +#include "request_handler_db_delete_stream.h" namespace asapo { @@ -26,7 +26,7 @@ class RequestFactory { RequestHandlerReceiveMetaData request_handler_receive_metadata_; RequestHandlerDbWrite request_handler_dbwrite_{kDBDataCollectionNamePrefix}; RequestHandlerDbStreamInfo request_handler_db_stream_info_{kDBDataCollectionNamePrefix}; - RequestHandlerDeleteStream request_handler_delete_stream_{kDBDataCollectionNamePrefix}; + RequestHandlerDbDeleteStream request_handler_delete_stream_{kDBDataCollectionNamePrefix}; RequestHandlerDbLastStream request_handler_db_last_stream_{kDBDataCollectionNamePrefix}; RequestHandlerDbMetaWrite request_handler_db_meta_write_{kDBMetaCollectionName}; RequestHandlerAuthorize request_handler_authorize_; diff --git a/receiver/src/request_handler/request_handler_db.cpp b/receiver/src/request_handler/request_handler_db.cpp index 821f0d770a551e27e2a1c0737e4902ce46e5519e..f991780b8e60484e567a01db4a4c29a5271229c0 100644 --- a/receiver/src/request_handler/request_handler_db.cpp +++ b/receiver/src/request_handler/request_handler_db.cpp @@ -2,6 +2,7 @@ #include "../receiver_config.h" #include "../receiver_logger.h" #include "../request.h" +#include "asapo/database/db_error.h" namespace asapo { @@ -12,13 +13,13 @@ Error RequestHandlerDb::ProcessRequest(Request* request) const { db_name_ += "_" + data_source; } - return ConnectToDbIfNeeded(); } -RequestHandlerDb::RequestHandlerDb(std::string collection_name_prefix): log__{GetDefaultReceiverLogger()}, - http_client__{DefaultHttpClient()}, - collection_name_prefix_{std::move(collection_name_prefix)} { +RequestHandlerDb::RequestHandlerDb(std::string collection_name_prefix) : log__{GetDefaultReceiverLogger()}, + http_client__{DefaultHttpClient()}, + collection_name_prefix_{ + std::move(collection_name_prefix)} { DatabaseFactory factory; Error err; db_client__ = factory.Create(&err); @@ -28,7 +29,6 @@ StatisticEntity RequestHandlerDb::GetStatisticEntity() const { return StatisticEntity::kDatabase; } - Error RequestHandlerDb::GetDatabaseServerUri(std::string* uri) const { if (GetReceiverConfig()->database_uri != "auto") { *uri = GetReceiverConfig()->database_uri; @@ -39,15 +39,17 @@ Error RequestHandlerDb::GetDatabaseServerUri(std::string* uri) const { Error http_err; *uri = http_client__->Get(GetReceiverConfig()->discovery_server + "/asapo-mongodb", &code, &http_err); if (http_err) { - log__->Error(std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server - + " : " + http_err->Explain()); + log__->Error( + std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server + + " : " + http_err->Explain()); return ReceiverErrorTemplates::kInternalServerError.Generate("http error when discover database server" + - http_err->Explain()); + http_err->Explain()); } if (code != HttpCode::OK) { - log__->Error(std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server - + " : http code" + std::to_string((int)code)); + log__->Error( + std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server + + " : http code" + std::to_string((int) code)); return ReceiverErrorTemplates::kInternalServerError.Generate("error when discover database server"); } @@ -56,7 +58,6 @@ Error RequestHandlerDb::GetDatabaseServerUri(std::string* uri) const { return nullptr; } - Error RequestHandlerDb::ConnectToDbIfNeeded() const { if (!connected_to_db) { std::string uri; @@ -66,12 +67,24 @@ Error RequestHandlerDb::ConnectToDbIfNeeded() const { } err = db_client__->Connect(uri, db_name_); if (err) { - return ReceiverErrorTemplates::kInternalServerError.Generate("error connecting to database " + err->Explain()); + return DBErrorToReceiverError(err); } connected_to_db = true; } return nullptr; } +Error RequestHandlerDb::DBErrorToReceiverError(const Error &err) const { + if (err == nullptr) { + return nullptr; + } + std::string msg = "database error: " + err->Explain(); + if (err == DBErrorTemplates::kWrongInput || err == DBErrorTemplates::kNoRecord + || err == DBErrorTemplates::kJsonParseError) { + return ReceiverErrorTemplates::kBadRequest.Generate(msg); + } + + return ReceiverErrorTemplates::kInternalServerError.Generate(msg); +} } diff --git a/receiver/src/request_handler/request_handler_db.h b/receiver/src/request_handler/request_handler_db.h index 0ea67d20538c2a42eb7ca61872a0840e9faca457..0c006e6e3eda365143ccae52baaa71ecca77a783 100644 --- a/receiver/src/request_handler/request_handler_db.h +++ b/receiver/src/request_handler/request_handler_db.h @@ -21,6 +21,7 @@ class RequestHandlerDb : public ReceiverRequestHandler { std::unique_ptr<HttpClient> http_client__; protected: Error ConnectToDbIfNeeded() const; + Error DBErrorToReceiverError(const Error& err) const; mutable bool connected_to_db = false; mutable std::string db_name_; std::string collection_name_prefix_; diff --git a/receiver/src/request_handler/request_handler_db_check_request.cpp b/receiver/src/request_handler/request_handler_db_check_request.cpp index b43347fee946e7bc5a48cbb31296afcb65487271..3d33e1441c1eb2394ab4bcb75c1173cef5f11f0d 100644 --- a/receiver/src/request_handler/request_handler_db_check_request.cpp +++ b/receiver/src/request_handler/request_handler_db_check_request.cpp @@ -58,7 +58,7 @@ Error RequestHandlerDbCheckRequest::ProcessRequest(Request* request) const { MessageMeta record; auto err = GetRecordFromDb(request, &record); if (err) { - return err == DBErrorTemplates::kNoRecord ? nullptr : std::move(err); + return DBErrorToReceiverError(err == DBErrorTemplates::kNoRecord ? nullptr : std::move(err)); } if (SameRequestInRecord(request, record)) { diff --git a/receiver/src/request_handler/request_handler_delete_stream.cpp b/receiver/src/request_handler/request_handler_db_delete_stream.cpp similarity index 78% rename from receiver/src/request_handler/request_handler_delete_stream.cpp rename to receiver/src/request_handler/request_handler_db_delete_stream.cpp index 42719ff49899275d3bd364e322d13e8f0172c482..8b8dfa357519597f56ea56c595f66b5500229b19 100644 --- a/receiver/src/request_handler/request_handler_delete_stream.cpp +++ b/receiver/src/request_handler/request_handler_db_delete_stream.cpp @@ -1,14 +1,14 @@ -#include "request_handler_delete_stream.h" +#include "request_handler_db_delete_stream.h" #include "../receiver_config.h" #include <asapo/database/db_error.h> namespace asapo { -RequestHandlerDeleteStream::RequestHandlerDeleteStream(std::string collection_name_prefix) : RequestHandlerDb( +RequestHandlerDbDeleteStream::RequestHandlerDbDeleteStream(std::string collection_name_prefix) : RequestHandlerDb( std::move(collection_name_prefix)) { } -Error RequestHandlerDeleteStream::ProcessRequest(Request* request) const { +Error RequestHandlerDbDeleteStream::ProcessRequest(Request* request) const { if (auto err = RequestHandlerDb::ProcessRequest(request) ) { return err; } @@ -36,7 +36,7 @@ Error RequestHandlerDeleteStream::ProcessRequest(Request* request) const { return nullptr; } - return err; + return DBErrorToReceiverError(err); } diff --git a/receiver/src/request_handler/request_handler_db_delete_stream.h b/receiver/src/request_handler/request_handler_db_delete_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..8c59271650563b59f10df428e18b270a7da5c590 --- /dev/null +++ b/receiver/src/request_handler/request_handler_db_delete_stream.h @@ -0,0 +1,18 @@ +#ifndef ASAPO_REQUEST_HANDLER_DB_DELETE_STREAM_H +#define ASAPO_REQUEST_HANDLER_DB_DELETE_STREAM_H + +#include "request_handler_db.h" +#include "../request.h" + +namespace asapo { + +class RequestHandlerDbDeleteStream final: public RequestHandlerDb { + public: + RequestHandlerDbDeleteStream(std::string collection_name_prefix); + Error ProcessRequest(Request* request) const override; +}; + +} + + +#endif //ASAPO_REQUEST_HANDLER_DB_DELETE_STREAM_H diff --git a/receiver/src/request_handler/request_handler_db_last_stream.cpp b/receiver/src/request_handler/request_handler_db_last_stream.cpp index 7e31468f565b0306b0334cc18420b43ba797f6d5..c41c49c6b8d5073f98fd71e139b764bb38fc11f9 100644 --- a/receiver/src/request_handler/request_handler_db_last_stream.cpp +++ b/receiver/src/request_handler/request_handler_db_last_stream.cpp @@ -22,7 +22,7 @@ Error RequestHandlerDbLastStream::ProcessRequest(Request* request) const { db_name_ + " at " + GetReceiverConfig()->database_uri); request->SetResponseMessage(info.Json(), ResponseMessageType::kInfo); } - return err; + return DBErrorToReceiverError(err); } } \ No newline at end of file diff --git a/receiver/src/request_handler/request_handler_db_meta_write.cpp b/receiver/src/request_handler/request_handler_db_meta_write.cpp index 57fb21051f846e708c858bb8396e5e968b44e867..3f1c89f0ff175626a3ad1a3cda47ece8ee32daf2 100644 --- a/receiver/src/request_handler/request_handler_db_meta_write.cpp +++ b/receiver/src/request_handler/request_handler_db_meta_write.cpp @@ -22,7 +22,7 @@ Error RequestHandlerDbMetaWrite::ProcessRequest(Request* request) const { db_name_ + " at " + GetReceiverConfig()->database_uri); } - return err; + return DBErrorToReceiverError(err); } RequestHandlerDbMetaWrite::RequestHandlerDbMetaWrite(std::string collection_name) : RequestHandlerDb(std::move( collection_name)) { diff --git a/receiver/src/request_handler/request_handler_db_stream_info.cpp b/receiver/src/request_handler/request_handler_db_stream_info.cpp index 65d194ccfa1f570fa51341d58e6e3b799a50528c..ff9a8ab935c208a986b41038a6be1ccd03e65c32 100644 --- a/receiver/src/request_handler/request_handler_db_stream_info.cpp +++ b/receiver/src/request_handler/request_handler_db_stream_info.cpp @@ -23,7 +23,7 @@ Error RequestHandlerDbStreamInfo::ProcessRequest(Request* request) const { info.name = request->GetStream(); request->SetResponseMessage(info.Json(), ResponseMessageType::kInfo); } - return err; + return DBErrorToReceiverError(err); } } \ No newline at end of file diff --git a/receiver/src/request_handler/request_handler_db_write.cpp b/receiver/src/request_handler/request_handler_db_write.cpp index d0113286aab69657d27e4ed7c69fdd272b411855..db28504e4f66991a18f257362d8e6e76ade59795 100644 --- a/receiver/src/request_handler/request_handler_db_write.cpp +++ b/receiver/src/request_handler/request_handler_db_write.cpp @@ -34,7 +34,7 @@ Error RequestHandlerDbWrite::ProcessRequest(Request* request) const { if (err == DBErrorTemplates::kDuplicateID) { return ProcessDuplicateRecordSituation(request); } else { - return err; + return DBErrorToReceiverError(err); } } diff --git a/receiver/src/request_handler/request_handler_delete_stream.h b/receiver/src/request_handler/request_handler_delete_stream.h deleted file mode 100644 index 3cf4e0fb0a46dd399c0e4dc4f23e2087c5f96790..0000000000000000000000000000000000000000 --- a/receiver/src/request_handler/request_handler_delete_stream.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef ASAPO_REQUEST_HANDLER_DELETE_STREAM_H -#define ASAPO_REQUEST_HANDLER_DELETE_STREAM_H - -#include "request_handler_db.h" -#include "../request.h" - -namespace asapo { - -class RequestHandlerDeleteStream final: public RequestHandlerDb { - public: - RequestHandlerDeleteStream(std::string collection_name_prefix); - Error ProcessRequest(Request* request) const override; -}; - -} - - -#endif //ASAPO_REQUEST_HANDLER_DELETE_STREAM_H diff --git a/receiver/src/request_handler/requests_dispatcher.cpp b/receiver/src/request_handler/requests_dispatcher.cpp index 79d414af2535316d025c13d9c4cdca127ddb05b9..ad009aba5addd6dbc66bec875df4900554c3905f 100644 --- a/receiver/src/request_handler/requests_dispatcher.cpp +++ b/receiver/src/request_handler/requests_dispatcher.cpp @@ -22,7 +22,7 @@ NetworkErrorCode GetNetworkCodeFromError(const Error& err) { return NetworkErrorCode::kNetErrorNotSupported; } else if (err == ReceiverErrorTemplates::kReAuthorizationFailure) { return NetworkErrorCode::kNetErrorReauthorize; - } else if (err == DBErrorTemplates::kJsonParseError || err == ReceiverErrorTemplates::kBadRequest || err == DBErrorTemplates::kNoRecord) { + } else if (err == ReceiverErrorTemplates::kBadRequest) { return NetworkErrorCode::kNetErrorWrongRequest; } else { return NetworkErrorCode::kNetErrorInternalServerError; diff --git a/receiver/unittests/request_handler/test_request_factory.cpp b/receiver/unittests/request_handler/test_request_factory.cpp index b9b8418e0d3591acf378b2f2704f54678789a1b3..363cd3cdcf115b88f793cf87b4977cfda16f93aa 100644 --- a/receiver/unittests/request_handler/test_request_factory.cpp +++ b/receiver/unittests/request_handler/test_request_factory.cpp @@ -15,7 +15,7 @@ #include "../../src/request_handler/request_handler_authorize.h" #include "../../src/request_handler/request_handler_db_stream_info.h" #include "../../src/request_handler/request_handler_db_last_stream.h" -#include "../../src/request_handler/request_handler_delete_stream.h" +#include "../../src/request_handler/request_handler_db_delete_stream.h" #include "../../src/request_handler/request_handler_receive_data.h" #include "../../src/request_handler/request_handler_receive_metadata.h" @@ -220,7 +220,7 @@ TEST_F(FactoryTests, DeleteStreamRequest) { ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(request->GetListHandlers().size(), Eq(2)); ASSERT_THAT(dynamic_cast<const asapo::RequestHandlerAuthorize*>(request->GetListHandlers()[0]), Ne(nullptr)); - ASSERT_THAT(dynamic_cast<const asapo::RequestHandlerDeleteStream*>(request->GetListHandlers()[1]), Ne(nullptr)); + ASSERT_THAT(dynamic_cast<const asapo::RequestHandlerDbDeleteStream*>(request->GetListHandlers()[1]), Ne(nullptr)); } diff --git a/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp b/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp index 85c2b5da91967daf92d26146e08bd6eef6030e4f..05d43fca83158008680a543895734eb8dcf7e9f9 100644 --- a/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp +++ b/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp @@ -261,7 +261,7 @@ TEST_F(DbCheckRequestHandlerTests, ErrorIfDbError) { for (auto mock : mock_functions) { mock(asapo::DBErrorTemplates::kConnectionError.Generate().release(), false); auto err = handler.ProcessRequest(mock_request.get()); - ASSERT_THAT(err, Eq(asapo::DBErrorTemplates::kConnectionError)); + ASSERT_THAT(err, Eq(asapo::ReceiverErrorTemplates::kInternalServerError)); Mock::VerifyAndClearExpectations(mock_request.get()); } } diff --git a/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp b/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp index 277ccdea4003e9dfafa819fc57886c423a772654..1f7286d3346660f3251d9ff61c38bc33ae3014be 100644 --- a/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp +++ b/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp @@ -10,7 +10,7 @@ #include "../../src/request.h" #include "../../src/request_handler/request_factory.h" #include "../../src/request_handler/request_handler.h" -#include "../../src/request_handler/request_handler_delete_stream.h" +#include "../../src/request_handler/request_handler_db_delete_stream.h" #include "../../../common/cpp/src/database/mongodb_client.h" #include "../mock_receiver_config.h" @@ -43,7 +43,7 @@ using ::asapo::FileDescriptor; using ::asapo::SocketDescriptor; using ::asapo::MockIO; using asapo::Request; -using asapo::RequestHandlerDeleteStream; +using asapo::RequestHandlerDbDeleteStream; using ::asapo::GenericRequestHeader; using asapo::MockDatabase; @@ -56,7 +56,7 @@ namespace { class DbMetaDeleteStreamTests : public Test { public: - RequestHandlerDeleteStream handler{asapo::kDBDataCollectionNamePrefix}; + RequestHandlerDbDeleteStream handler{asapo::kDBDataCollectionNamePrefix}; std::unique_ptr<NiceMock<MockRequest>> mock_request; NiceMock<MockDatabase> mock_db; NiceMock<asapo::MockLogger> mock_logger; @@ -130,7 +130,7 @@ TEST_F(DbMetaDeleteStreamTests, CallsDeleteErrorAlreadyExist) { ExpectDelete(3,&asapo::DBErrorTemplates::kNoRecord); auto err = handler.ProcessRequest(mock_request.get()); - ASSERT_THAT(err, Eq(asapo::DBErrorTemplates::kNoRecord)); + ASSERT_THAT(err, Eq(asapo::ReceiverErrorTemplates::kBadRequest)); } TEST_F(DbMetaDeleteStreamTests, CallsDeleteNoErrorAlreadyExist) { diff --git a/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh b/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh index adf03ce9569285fb5f692cc9f9c96685cc1268d0..71093a2577be0c3a864a60150290b04338c6f057 100644 --- a/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh +++ b/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -database_name=data +database_name=data_%2F%20%5C%2E%22%24 echo "db.dropDatabase()" | mongo ${database_name} diff --git a/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat b/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat index 53b3cbee6ad380a90cb12999ff08e6724fe90d7b..74beec0f222e1aea00977c7b7d75bb6536245e3e 100644 --- a/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat +++ b/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat @@ -1,4 +1,4 @@ -SET database_name=data +SET database_name=data_%2F%20%5C%2E%22%24 SET mongo_exe="c:\Program Files\MongoDB\Server\4.2\bin\mongo.exe" echo db.dropDatabase() | %mongo_exe% %database_name% diff --git a/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp b/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp index ac95431b257f15d3cfc9d3279937a7ec042c1fdf..024c97827665611f14205898b453ec232fb3df9d 100644 --- a/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp +++ b/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp @@ -3,6 +3,8 @@ #include <thread> #include "../../../common/cpp/src/database/mongodb_client.h" +#include "../../../common/cpp/src/database/encoding.h" + #include "testing.h" #include "asapo/common/data_structs.h" @@ -44,17 +46,19 @@ int main(int argc, char* argv[]) { fi.timestamp = std::chrono::system_clock::now(); fi.buf_id = 18446744073709551615ull; fi.source = "host:1234"; - + + auto db_name = R"(data_/ \."$)"; + auto stream_name = R"(bla/test_/\ ."$)"; if (args.keyword != "Notconnected") { - db.Connect("127.0.0.1", "data"); + db.Connect("127.0.0.1", db_name); } - auto err = db.Insert("data_test", fi, false); + auto err = db.Insert(std::string("data_")+stream_name, fi, false); if (args.keyword == "DuplicateID") { Assert(err, "OK"); - err = db.Insert("data_test", fi, false); + err = db.Insert(std::string("data_")+stream_name, fi, false); } std::this_thread::sleep_for(std::chrono::milliseconds(10)); @@ -73,17 +77,17 @@ int main(int argc, char* argv[]) { if (args.keyword == "OK") { // check retrieve and stream delete asapo::MessageMeta fi_db; asapo::MongoDBClient db_new; - db_new.Connect("127.0.0.1", "data"); - err = db_new.GetById("data_test", fi.id, &fi_db); + db_new.Connect("127.0.0.1", db_name); + err = db_new.GetById(std::string("data_")+stream_name, fi.id, &fi_db); M_AssertTrue(fi_db == fi, "get record from db"); M_AssertEq(nullptr, err); - err = db_new.GetById("data_test", 0, &fi_db); + err = db_new.GetById(std::string("data_")+stream_name, 0, &fi_db); Assert(err, "No record"); asapo::StreamInfo info; - err = db.GetStreamInfo("data_test", &info); + err = db.GetStreamInfo(std::string("data_")+stream_name, &info); M_AssertEq(nullptr, err); M_AssertEq(fi.id, info.last_id); @@ -95,20 +99,20 @@ int main(int argc, char* argv[]) { M_AssertEq("ns",info.next_stream); // delete stream - db.Insert("inprocess_test_blabla", fi, false); - db.Insert("inprocess_test_blabla1", fi, false); - db.Insert("acks_test_blabla", fi, false); - db.Insert("acks_test_blabla1", fi, false); - db.DeleteStream("test"); - err = db.GetStreamInfo("data_test", &info); + db.Insert(std::string("inprocess_")+stream_name+"_blabla", fi, false); + db.Insert(std::string("inprocess_")+stream_name+"_blabla1", fi, false); + db.Insert(std::string("acks_")+stream_name+"_blabla", fi, false); + db.Insert(std::string("acks_")+stream_name+"_blabla1", fi, false); + db.DeleteStream(stream_name); + err = db.GetStreamInfo(std::string("data_")+stream_name, &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("inprocess_test_blabla", &info); + err = db.GetStreamInfo(std::string("inprocess_")+stream_name+"_blabla", &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("inprocess_test_blabla1", &info); + err = db.GetStreamInfo(std::string("inprocess_")+stream_name+"_blabla1", &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("acks_test_blabla", &info); + err = db.GetStreamInfo(std::string("acks_")+stream_name+"_blabla", &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("acks_test_blabla1", &info); + err = db.GetStreamInfo(std::string("acks_")+stream_name+"_blabla1", &info); M_AssertTrue(info.last_id == 0); err = db.DeleteStream("test1"); M_AssertTrue(err==nullptr);