diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bd981a529541c6d97dc7c94f1119115aabcd3c5..20f2f11a35d6c8d1f10b2240bc772ad1397c84b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ ## 21.06.0 (in progress) +IMPROVEMENTS +* Consumer/Producer API - allow any characters in source/stream/group names BUG FIXES * Consumer API: multiple consumers from same group receive stream finished error diff --git a/CMakeLists.txt b/CMakeLists.txt index 392ff92b91f385cb84e1c21eb7c7efffea856008..8ac928819283903b0d364efd8e614df3c37a434b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,13 +3,13 @@ project(ASAPO) #protocol version changes if one of the microservice API's change set (ASAPO_CONSUMER_PROTOCOL "v0.4") -set (ASAPO_PRODUCER_PROTOCOL "v0.2") +set (ASAPO_PRODUCER_PROTOCOL "v0.3") set (ASAPO_DISCOVERY_API_VER "v0.1") set (ASAPO_AUTHORIZER_API_VER "v0.1") set (ASAPO_BROKER_API_VER "v0.4") set (ASAPO_FILE_TRANSFER_SERVICE_API_VER "v0.1") -set (ASAPO_RECEIVER_API_VER "v0.2") +set (ASAPO_RECEIVER_API_VER "v0.3") set (ASAPO_RDS_API_VER "v0.1") set(CMAKE_CXX_STANDARD 11) diff --git a/CMakeModules/coverage_go.sh b/CMakeModules/coverage_go.sh index 5b01cf48bb047093f0605fb0b115751c1aa38be4..5b74a200569faacd779cddac041a7399aa67c5de 100755 --- a/CMakeModules/coverage_go.sh +++ b/CMakeModules/coverage_go.sh @@ -15,7 +15,7 @@ for pkg in ${PACKAGES[@]} do # echo $pkg go test -coverprofile=$OUT_DIR/coverage.out -tags test $pkg #>/dev/null 2>&1 - tail -n +2 $OUT_DIR/coverage.out | grep -v kubernetes >> $OUT_DIR/coverage-all.out #2>/dev/null + tail -n +2 $OUT_DIR/coverage.out | grep -v -e kubernetes -e _nottested >> $OUT_DIR/coverage-all.out #2>/dev/null done coverage=`go tool cover -func=$OUT_DIR/coverage-all.out | grep total | cut -d ")" -f 2 | cut -d "." -f 1` diff --git a/CMakeModules/testing_cpp.cmake b/CMakeModules/testing_cpp.cmake index 95364659d91c2e92923eee675303f57f51e86506..0c000601b4f9a70f7049e2878c85c8bad6c0c158 100644 --- a/CMakeModules/testing_cpp.cmake +++ b/CMakeModules/testing_cpp.cmake @@ -103,10 +103,11 @@ function(gtest target test_source_files linktarget) endif () add_test(NAME test-${target} COMMAND test-${target}) set_tests_properties(test-${target} PROPERTIES LABELS "unit;all") - message(STATUS "Added test 'test-${target}'") - - if (CMAKE_COMPILER_IS_GNUCXX) + if (ARGN) + LIST(GET ARGN 0 NOCOV) + endif() + if (CMAKE_COMPILER_IS_GNUCXX AND NOT 1${NOCOV} STREQUAL "1nocov") set(COVERAGE_EXCLUDES "*/unittests/*" "*/3d_party/*" "*/python/*") if (ARGN) set(COVERAGE_EXCLUDES ${COVERAGE_EXCLUDES} ${ARGN}) @@ -116,6 +117,8 @@ function(gtest target test_source_files linktarget) COMMAND ${CMAKE_MODULE_PATH}/check_test.sh coverage-${target} ${CMAKE_BINARY_DIR} ${ASAPO_MINIMUM_COVERAGE}) set_tests_properties(coveragetest-${target} PROPERTIES LABELS "coverage;all") + message(STATUS "Added test 'test-${target}-coverage'") + SET_TESTS_PROPERTIES(coveragetest-${target} PROPERTIES DEPENDS test-${target}) set(CMAKE_C_FLAGS ${CMAKE_C_FLAGS} PARENT_SCOPE) set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} PARENT_SCOPE) diff --git a/PROTOCOL-VERSIONS.md b/PROTOCOL-VERSIONS.md index 6481b4fd64fe9836574cddc855b2f13c3b27a1d2..e96812b7a392a8b81bb04305847b7e011b261e4e 100644 --- a/PROTOCOL-VERSIONS.md +++ b/PROTOCOL-VERSIONS.md @@ -1,14 +1,15 @@ ### Producer Protocol | Release | used by client | Supported by server | Status | | ------------ | ------------------- | -------------------- | ---------------- | -| v0.2 | 21.03.2 - 21.03.2 | 21.03.2 - 21.03.2 | Current version | -| v0.1 | 21.03.0 - 21.03.1 | 21.03.0 - 21.03.2 | Deprecates from 01.06.2022 | +| v0.3 | 21.06.0 - 21.06.0 | 21.06.0 - 21.06.0 | Current version | +| v0.2 | 21.03.2 - 21.03.2 | 21.03.2 - 21.06.0 | Deprecates from 01.07.2022 | +| v0.1 | 21.03.0 - 21.03.1 | 21.03.0 - 21.06.0 | Deprecates from 01.06.2022 | ### Consumer Protocol | Release | used by client | Supported by server | Status | | ------------ | ------------------- | -------------------- | ---------------- | | v0.4 | 21.06.0 - 21.06.0 | 21.06.0 - 21.06.0 | Current version | -| v0.3 | 21.03.3 - 21.03.3 | 21.03.3 - 21.06.0 | Deprecates from 01.06.2022 | +| v0.3 | 21.03.3 - 21.03.3 | 21.03.3 - 21.06.0 | Deprecates from 01.07.2022 | | v0.2 | 21.03.2 - 21.03.2 | 21.03.2 - 21.06.0 | Deprecates from 01.06.2022 | | v0.1 | 21.03.0 - 21.03.1 | 21.03.0 - 21.06.0 | Deprecates from 01.06.2022 | diff --git a/VERSIONS.md b/VERSIONS.md index 37c226d10fb37341d486979d134de44a61b5ff69..3196aab3c56a46b9bbca1616925caba94e5e9d23 100644 --- a/VERSIONS.md +++ b/VERSIONS.md @@ -2,8 +2,9 @@ | Release | API changed\*\* | Protocol | Supported by server from/to | Status |Comment| | ------------ | ----------- | -------- | ------------------------- | --------------------- | ------- | -| 21.03.3 | No | v0.2 | 21.03.2/21.03.3 | current version |bugfix in server| -| 21.03.2 | Yes | v0.2 | 21.03.2/21.03.3 | current version |bugfixes, add delete_stream| +| 21.06.0 | Yes | v0.3 | 21.06.0/21.06.0 | current version |arbitrary characters| +| 21.03.3 | No | v0.2 | 21.03.2/21.03.3 | deprecates 01.07.2022 |bugfix in server| +| 21.03.2 | Yes | v0.2 | 21.03.2/21.03.3 | deprecates 01.07.2022 |bugfixes, add delete_stream| | 21.03.1 | No | v0.1 | 21.03.0/21.03.3 | deprecates 01.06.2022 |bugfix in server| | 21.03.0 | Yes | v0.1 | 21.03.0/21.03.3 | | | @@ -11,7 +12,7 @@ | Release | API changed\*\* | Protocol | Supported by server from/to | Status |Comment| | ------------ | ----------- | --------- | ------------------------- | ---------------- | ------- | -| 21.06.0 | No* | v0.4 | 21.06.0/21.06.0 | current version |bugfixes | +| 21.06.0 | Yes | v0.4 | 21.06.0/21.06.0 | current version |arbitrary characters, bugfixes | | 21.03.3 | Yes | v0.3 | 21.03.3/21.06.0 | deprecates 01.06.2022 |bugfix in server, error type for dublicated ack| | 21.03.2 | Yes | v0.2 | 21.03.2/21.06.0 | deprecates 01.06.2022 |bugfixes, add delete_stream| | 21.03.1 | No | v0.1 | 21.03.0/21.06.0 | deprecates 01.06.2022 |bugfix in server| diff --git a/authorizer/src/asapo_authorizer/server/authorize.go b/authorizer/src/asapo_authorizer/server/authorize.go index b47b7adcef4b7287bd5312019dea9055be3945ab..26f8c92d19d59846303a666ae63b10089e3e01d0 100644 --- a/authorizer/src/asapo_authorizer/server/authorize.go +++ b/authorizer/src/asapo_authorizer/server/authorize.go @@ -25,12 +25,16 @@ type authorizationRequest struct { } func getSourceCredentials(request authorizationRequest) (SourceCredentials, error) { - vals := strings.Split(request.SourceCredentials, "%") - if len(vals) != 5 { + + vals := strings.Split(request.SourceCredentials, "%") + nvals:=len(vals) + if nvals < 5 { return SourceCredentials{}, errors.New("cannot get source credentials from " + request.SourceCredentials) } - creds := SourceCredentials{vals[1], vals[2], vals[3], vals[4],vals[0]} + + creds := SourceCredentials{Type:vals[0], BeamtimeId: vals[1], Beamline: vals[2], Token:vals[nvals-1]} + creds.DataSource=strings.Join(vals[3:nvals-1],"%") if creds.DataSource == "" { creds.DataSource = "detector" } diff --git a/authorizer/src/asapo_authorizer/server/authorize_test.go b/authorizer/src/asapo_authorizer/server/authorize_test.go index 11c072156664e6d4eb959164c0dd4cf0a92869be..12268d73dd9d476f46e6a2e187ea3a7b6d381a94 100644 --- a/authorizer/src/asapo_authorizer/server/authorize_test.go +++ b/authorizer/src/asapo_authorizer/server/authorize_test.go @@ -91,6 +91,8 @@ var credTests = [] struct { {"raw%%beamline%source%token", SourceCredentials{"auto","beamline","source","token","raw"},true,"empty beamtime"}, {"raw%asapo_test%%source%token", SourceCredentials{"asapo_test","auto","source","token","raw"},true,"empty bealine"}, {"raw%%%source%token", SourceCredentials{},false,"both empty"}, + {"processed%asapo_test%beamline%source%blabla%token", SourceCredentials{"asapo_test","beamline","source%blabla","token","processed"},true,"% in source"}, + {"processed%asapo_test%beamline%source%blabla%", SourceCredentials{"asapo_test","beamline","source%blabla","","processed"},true,"% in source, no token"}, } func TestSplitCreds(t *testing.T) { @@ -100,7 +102,7 @@ func TestSplitCreds(t *testing.T) { creds,err := getSourceCredentials(request) if test.ok { assert.Nil(t,err) - assert.Equal(t,creds,test.cred,test.message) + assert.Equal(t,test.cred,creds,test.message) } else { assert.NotNil(t,err,test.message) } diff --git a/broker/src/asapo_broker/database/encoding.go b/broker/src/asapo_broker/database/encoding.go new file mode 100644 index 0000000000000000000000000000000000000000..7f08397975d94553a61a87dad251dbdb55ec2864 --- /dev/null +++ b/broker/src/asapo_broker/database/encoding.go @@ -0,0 +1,99 @@ +package database + +import ( + "asapo_common/utils" + "net/url" +) + +const max_encoded_source_size = 63 +const max_encoded_stream_size = 100 +const max_encoded_group_size = 50 + +func shouldEscape(c byte, db bool) bool { + if c == '$' || c == ' ' || c == '%' { + return true + } + if !db { + return false + } + + switch c { + case '\\', '/', '.', '"': + return true + } + return false +} + +const upperhex = "0123456789ABCDEF" + +func escape(s string, db bool) string { + hexCount := 0 + for i := 0; i < len(s); i++ { + c := s[i] + if shouldEscape(c, db) { + hexCount++ + } + } + + if hexCount == 0 { + return s + } + + var buf [64]byte + var t []byte + + required := len(s) + 2*hexCount + if required <= len(buf) { + t = buf[:required] + } else { + t = make([]byte, required) + } + + j := 0 + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case shouldEscape(c, db): + t[j] = '%' + t[j+1] = upperhex[c>>4] + t[j+2] = upperhex[c&15] + j += 3 + default: + t[j] = s[i] + j++ + } + } + return string(t) +} + +func encodeStringForDbName(original string) (result string) { + return escape(original, true) +} + + +func decodeString(original string) (result string) { + result,_ = url.PathUnescape(original) + return result +} + +func encodeStringForColName(original string) (result string) { + return escape(original, false) +} + +func encodeRequest(request *Request) error { + request.DbName = encodeStringForDbName(request.DbName) + if len(request.DbName)> max_encoded_source_size { + return &DBError{utils.StatusWrongInput, "source name is too long"} + } + + request.DbCollectionName = encodeStringForColName(request.DbCollectionName) + if len(request.DbCollectionName)> max_encoded_stream_size { + return &DBError{utils.StatusWrongInput, "stream name is too long"} + } + + request.GroupId = encodeStringForColName(request.GroupId) + if len(request.GroupId)> max_encoded_group_size { + return &DBError{utils.StatusWrongInput, "group id is too long"} + } + + return nil +} diff --git a/broker/src/asapo_broker/database/encoding_test.go b/broker/src/asapo_broker/database/encoding_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1def90c99f6a2268883530be39b62fbb01eabb96 --- /dev/null +++ b/broker/src/asapo_broker/database/encoding_test.go @@ -0,0 +1,82 @@ +package database + +import ( + "asapo_common/utils" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func TestEncoding(t *testing.T) { + stream := `ss$` + source := `ads%&%41.sss` + streamEncoded := encodeStringForColName(stream) + sourceEncoded := encodeStringForDbName(source) + streamDecoded := decodeString(streamEncoded) + sourceDecoded := decodeString(sourceEncoded) + assert.Equal(t, streamDecoded, stream) + assert.Equal(t, sourceDecoded, source) + + r := Request{ + DbName: source, + DbCollectionName: stream, + GroupId: stream, + Op: "", + DatasetOp: false, + MinDatasetSize: 0, + ExtraParam: "", + } + err := encodeRequest(&r) + assert.Equal(t, r.DbCollectionName, streamEncoded) + assert.Equal(t, r.GroupId, streamEncoded) + assert.Equal(t, r.DbName, sourceEncoded) + + assert.Nil(t, err) +} + +var encodeTests = []struct { + streamSize int + groupSize int + sourceSize int + ok bool + message string +}{ + {max_encoded_stream_size, max_encoded_group_size, max_encoded_source_size, true, "ok"}, + {max_encoded_stream_size + 1, max_encoded_group_size, max_encoded_source_size, false, "stream"}, + {max_encoded_stream_size, max_encoded_group_size + 1, max_encoded_source_size, false, "group"}, + {max_encoded_stream_size, max_encoded_group_size, max_encoded_source_size + 1, false, "source"}, +} + +func RandomString(n int) string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, n) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} + +func TestEncodingTooLong(t *testing.T) { + for _, test := range encodeTests { + stream := RandomString(test.streamSize) + group := RandomString(test.groupSize) + source := RandomString(test.sourceSize) + r := Request{ + DbName: source, + DbCollectionName: stream, + GroupId: group, + Op: "", + DatasetOp: false, + MinDatasetSize: 0, + ExtraParam: "", + } + err := encodeRequest(&r) + if test.ok { + assert.Nil(t, err, test.message) + } else { + assert.Equal(t, utils.StatusWrongInput, err.(*DBError).Code) + assert.Contains(t,err.Error(),test.message,test.message) + } + } +} diff --git a/broker/src/asapo_broker/database/mongodb.go b/broker/src/asapo_broker/database/mongodb.go index 5491086a02aebf133e75639a9498cecc956e9c31..75c885f30ee4290bf9d5980e2f23ebd25b4e9000 100644 --- a/broker/src/asapo_broker/database/mongodb.go +++ b/broker/src/asapo_broker/database/mongodb.go @@ -8,7 +8,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" @@ -67,7 +66,6 @@ const meta_collection_name = "meta" const pointer_collection_name = "current_location" const pointer_field_name = "current_pointer" const no_session_msg = "database client not created" -const wrong_id_type = "wrong id type" const already_connected_msg = "already connected" const finish_stream_keyword = "asapo_finish_stream" @@ -219,7 +217,7 @@ func (db *Mongodb) setCounter(request Request, ind int) (err error) { } func (db *Mongodb) errorWhenCannotIncrementField(request Request, max_ind int) (err error) { - if res, err := db.getRecordFromDb(request, max_ind, max_ind);err == nil { + if res, err := db.getRecordFromDb(request, max_ind, max_ind); err == nil { if err := checkStreamFinished(request, max_ind, max_ind, res); err != nil { return err } @@ -227,7 +225,6 @@ func (db *Mongodb) errorWhenCannotIncrementField(request Request, max_ind int) ( return &DBError{utils.StatusNoData, encodeAnswer(max_ind, max_ind, "")} } - func (db *Mongodb) incrementField(request Request, max_ind int, res interface{}) (err error) { update := bson.M{"$inc": bson.M{pointer_field_name: 1}} opts := options.FindOneAndUpdate().SetUpsert(true).SetReturnDocument(options.After) @@ -244,7 +241,7 @@ func (db *Mongodb) incrementField(request Request, max_ind int, res interface{}) if err2 := c.FindOneAndUpdate(context.TODO(), q, update, opts).Decode(res); err2 == nil { return nil } - return db.errorWhenCannotIncrementField(request,max_ind) + return db.errorWhenCannotIncrementField(request, max_ind) } return &DBError{utils.StatusTransactionInterrupted, err.Error()} } @@ -595,7 +592,6 @@ func checkStreamFinished(request Request, id, id_max int, data map[string]interf return nil } r, ok := ExtractMessageRecord(data) - fmt.Println(r,ok) if !ok || !r.FinishedStream { return nil } @@ -855,9 +851,17 @@ func (db *Mongodb) deleteDocumentsInCollection(request Request, collection strin return err } +func escapeQuery(query string )(res string) { + chars := `\-[]{}()*+?.,^$|#` + for _, char := range chars { + query = strings.ReplaceAll(query,string(char),`\`+string(char)) + } + return query +} + func (db *Mongodb) deleteCollectionsWithPrefix(request Request, prefix string) error { cols, err := db.client.Database(request.DbName).ListCollectionNames(context.TODO(), bson.M{"name": bson.D{ - {"$regex", primitive.Regex{Pattern: "^" + prefix, Options: "i"}}}}) + {"$regex", primitive.Regex{Pattern: "^" + escapeQuery(prefix), Options: "i"}}}}) if err != nil { return err } @@ -881,7 +885,7 @@ func (db *Mongodb) deleteServiceMeta(request Request) error { if err != nil { return err } - return db.deleteDocumentsInCollection(request, pointer_collection_name, "_id", ".*_"+request.DbCollectionName+"$") + return db.deleteDocumentsInCollection(request, pointer_collection_name, "_id", ".*_"+escapeQuery(request.DbCollectionName)+"$") } func (db *Mongodb) deleteStream(request Request) ([]byte, error) { @@ -1000,11 +1004,16 @@ func (db *Mongodb) getStreams(request Request) ([]byte, error) { return json.Marshal(&rec) } + func (db *Mongodb) ProcessRequest(request Request) (answer []byte, err error) { if err := db.checkDatabaseOperationPrerequisites(request); err != nil { return nil, err } + if err := encodeRequest(&request); err != nil { + return nil, err + } + switch request.Op { case "next": return db.getNextRecord(request) diff --git a/broker/src/asapo_broker/database/mongodb_streams.go b/broker/src/asapo_broker/database/mongodb_streams.go index 278ef3c57062196067b1d78c7814b0ecfcfba70e..243df816d1e182c2dd767e5dc0db106a57ee40d8 100644 --- a/broker/src/asapo_broker/database/mongodb_streams.go +++ b/broker/src/asapo_broker/database/mongodb_streams.go @@ -57,7 +57,8 @@ func readStreams(db *Mongodb, db_name string) (StreamsRecord, error) { var rec = StreamsRecord{[]StreamInfo{}} for _, coll := range result { if strings.HasPrefix(coll, data_collection_name_prefix) { - si := StreamInfo{Name: strings.TrimPrefix(coll, data_collection_name_prefix)} + sNameEncoded:= strings.TrimPrefix(coll, data_collection_name_prefix) + si := StreamInfo{Name: decodeString(sNameEncoded)} rec.Streams = append(rec.Streams, si) } } @@ -88,7 +89,7 @@ func findStreamAmongCurrent(currentStreams []StreamInfo, record StreamInfo) (int } func fillInfoFromEarliestRecord(db *Mongodb, db_name string, rec *StreamsRecord, record StreamInfo, i int) error { - res, err := db.getEarliestRawRecord(db_name, record.Name) + res, err := db.getEarliestRawRecord(db_name, encodeStringForColName(record.Name)) if err != nil { return err } @@ -102,7 +103,7 @@ func fillInfoFromEarliestRecord(db *Mongodb, db_name string, rec *StreamsRecord, } func fillInfoFromLastRecord(db *Mongodb, db_name string, rec *StreamsRecord, record StreamInfo, i int) error { - res, err := db.getLastRawRecord(db_name, record.Name) + res, err := db.getLastRawRecord(db_name, encodeStringForColName(record.Name)) if err != nil { return err } diff --git a/broker/src/asapo_broker/database/mongodb_test.go b/broker/src/asapo_broker/database/mongodb_test.go index 72471969075187ea4e5226bec89a167a63f26d07..9b4742ff9c0146509df402fe5295a946806b54c5 100644 --- a/broker/src/asapo_broker/database/mongodb_test.go +++ b/broker/src/asapo_broker/database/mongodb_test.go @@ -36,6 +36,11 @@ const groupId = "bid2a5auidddp1vl71d0" const metaID = 0 const metaID_str = "0" +const badSymbolsDb = `/\."$` +const badSymbolsCol = `$` +const badSymbolsDbEncoded = "%2F%5C%2E%22%24" +const badSymbolsColEncoded ="%24" + var empty_next = map[string]string{"next_stream": ""} var rec1 = TestRecord{1, empty_next, "aaa", 0} @@ -63,6 +68,15 @@ func cleanup() { db.Close() } +func cleanupWithName(name string) { + if db.client == nil { + return + } + db.dropDatabase(name) + db.Close() +} + + // these are the integration tests. They assume mongo db is runnig on 127.0.0.1:27027 // test names should contain MongoDB*** so that go test could find them: // go_integration_test(${TARGET_NAME}-connectdb "./..." "MongoDBConnect") @@ -882,6 +896,10 @@ var testsStreams = []struct { StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss1", Timestamp: 0, LastId: 2, TimestampLast: 1}, StreamInfo{Name: "ss2", Timestamp: 1, LastId: 3, TimestampLast: 2}}}, "two streams", true}, {"ss2", []Stream{{"ss1", []TestRecord{rec1, rec2}}, {"ss2", []TestRecord{rec2, rec3}}}, StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss2", Timestamp: 1, LastId: 3, TimestampLast: 2}}}, "with from", true}, + {"", []Stream{{"ss1$", []TestRecord{rec2, rec1}}}, + StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss1$", Timestamp: 0, LastId: 2, TimestampLast: 1}}}, "one stream encoded", true}, + {"ss2$", []Stream{{"ss1$", []TestRecord{rec1, rec2}}, {"ss2$", []TestRecord{rec2, rec3}}}, StreamsRecord{[]StreamInfo{StreamInfo{Name: "ss2$", Timestamp: 1, LastId: 3, TimestampLast: 2}}}, "with from encoded", true}, + } func TestMongoDBListStreams(t *testing.T) { @@ -889,7 +907,7 @@ func TestMongoDBListStreams(t *testing.T) { db.Connect(dbaddress) for _, stream := range test.streams { for _, rec := range stream.records { - db.insertRecord(dbname, stream.name, &rec) + db.insertRecord(dbname, encodeStringForColName(stream.name), &rec) } } var rec_streams_expect, _ = json.Marshal(test.expectedStreams) @@ -1196,14 +1214,21 @@ var testsDeleteStream = []struct { }{ {"test", "{\"ErrorOnNotExist\":true,\"DeleteMeta\":true}", true,false, "delete stream"}, {"test", "{\"ErrorOnNotExist\":false,\"DeleteMeta\":true}", true, true,"delete stream"}, + {`test$/\ .%&?*#'`, "{\"ErrorOnNotExist\":false,\"DeleteMeta\":true}", true, true,"delete stream"}, + } func TestDeleteStreams(t *testing.T) { + defer cleanup() for _, test := range testsDeleteStream { db.Connect(dbaddress) - db.insertRecord(dbname, test.stream, &rec_finished11) - - _, err := db.ProcessRequest(Request{DbName: dbname, DbCollectionName: test.stream, GroupId: "", Op: "delete_stream", ExtraParam: test.params}) + db.insertRecord(dbname, encodeStringForColName(test.stream), &rec1) + db.ProcessRequest(Request{DbName: dbname, DbCollectionName: test.stream, GroupId: "123", Op: "next"}) + query_str := "{\"Id\":1,\"Op\":\"ackmessage\"}" + request := Request{DbName: dbname, DbCollectionName: test.stream, GroupId: groupId, Op: "ackmessage", ExtraParam: query_str} + _, err := db.ProcessRequest(request) + assert.Nil(t, err, test.message) + _, err = db.ProcessRequest(Request{DbName: dbname, DbCollectionName: test.stream, GroupId: "", Op: "delete_stream", ExtraParam: test.params}) if test.ok { rec, err := streams.getStreams(&db, Request{DbName: dbname, ExtraParam: ""}) acks_exist,_:= db.collectionExist(Request{DbName: dbname, ExtraParam: ""},acks_collection_name_prefix+test.stream) @@ -1223,3 +1248,36 @@ func TestDeleteStreams(t *testing.T) { } } } + + +var testsEncodings = []struct { + dbname string + collection string + group string + dbname_indb string + collection_indb string + group_indb string + message string + ok bool +}{ + {"dbname", "col", "group", "dbname","col","group", "no encoding",true}, + {"dbname"+badSymbolsDb, "col", "group", "dbname"+badSymbolsDbEncoded,"col","group", "symbols in db",true}, + {"dbname", "col"+badSymbolsCol, "group"+badSymbolsCol, "dbname","col"+badSymbolsColEncoded,"group"+badSymbolsColEncoded, "symbols in col",true}, + {"dbname"+badSymbolsDb, "col"+badSymbolsCol, "group"+badSymbolsCol, "dbname"+badSymbolsDbEncoded,"col"+badSymbolsColEncoded,"group"+badSymbolsColEncoded, "symbols in col and db",true}, + +} + +func TestMongoDBEncodingOK(t *testing.T) { + for _, test := range testsEncodings { + db.Connect(dbaddress) + db.insertRecord(test.dbname_indb, test.collection_indb, &rec1) + res, err := db.ProcessRequest(Request{DbName: test.dbname, DbCollectionName: test.collection, GroupId: test.group, Op: "next"}) + if test.ok { + assert.Nil(t, err, test.message) + assert.Equal(t, string(rec1_expect), string(res), test.message) + } else { + assert.Equal(t, utils.StatusWrongInput, err.(*DBError).Code, test.message) + } + cleanupWithName(test.dbname_indb) + } +} \ No newline at end of file diff --git a/broker/src/asapo_broker/server/get_commands_test.go b/broker/src/asapo_broker/server/get_commands_test.go index c472ddb4c3948004c412ea6267da70a582a72084..40c41c2b6bcb9ef70fd2d99567d7ab25caf1cb7d 100644 --- a/broker/src/asapo_broker/server/get_commands_test.go +++ b/broker/src/asapo_broker/server/get_commands_test.go @@ -6,6 +6,8 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "net/http" + "net/url" + "strings" "testing" ) @@ -34,23 +36,24 @@ func TestGetCommandsTestSuite(t *testing.T) { var testsGetCommand = []struct { command string + source string stream string groupid string reqString string queryParams string externalParam string }{ - {"last", expectedStream, "", expectedStream + "/0/last","","0"}, - {"id", expectedStream, "", expectedStream + "/0/1","","1"}, - {"meta", "default", "", "default/0/meta/0","","0"}, - {"nacks", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/nacks","","0_0"}, - {"next", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","",""}, - {"next", expectedStream, expectedGroupID, expectedStream + "/" + + {"last", expectedSource,expectedStream, "", expectedStream + "/0/last","","0"}, + {"id", expectedSource,expectedStream, "", expectedStream + "/0/1","","1"}, + {"meta", expectedSource,"default", "", "default/0/meta/0","","0"}, + {"nacks",expectedSource, expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/nacks","","0_0"}, + {"next", expectedSource,expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","",""}, + {"next", expectedSource,expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/next","&resend_nacks=true&delay_ms=10000&resend_attempts=3","10000_3"}, - {"size", expectedStream, "", expectedStream + "/size","",""}, - {"size", expectedStream, "", expectedStream + "/size","&incomplete=true","true"}, - {"streams", "0", "", "0/streams","","_"}, - {"lastack", expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/lastack","",""}, + {"size", expectedSource,expectedStream, "", expectedStream + "/size","",""}, + {"size",expectedSource, expectedStream, "", expectedStream + "/size","&incomplete=true","true"}, + {"streams",expectedSource, "0", "", "0/streams","","_"}, + {"lastack", expectedSource,expectedStream, expectedGroupID, expectedStream + "/" + expectedGroupID + "/lastack","",""}, } @@ -58,8 +61,33 @@ func (suite *GetCommandsTestSuite) TestGetCommandsCallsCorrectRoutine() { for _, test := range testsGetCommand { suite.mock_db.On("ProcessRequest", database.Request{DbName: expectedDBName, DbCollectionName: test.stream, GroupId: test.groupid, Op: test.command, ExtraParam: test.externalParam}).Return([]byte("Hello"), nil) logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request "+test.command))) - w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + expectedSource + "/" + test.reqString+correctTokenSuffix+test.queryParams) + w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + test.source + "/" + test.reqString+correctTokenSuffix+test.queryParams) suite.Equal(http.StatusOK, w.Code, test.command+ " OK") suite.Equal("Hello", string(w.Body.Bytes()), test.command+" sends data") } } + +func (suite *GetCommandsTestSuite) TestGetCommandsCorrectlyProcessedEncoding() { + badSymbols:="%$&./\\_$&\"" + for _, test := range testsGetCommand { + newstream := test.stream+badSymbols + newsource := test.source+badSymbols + newgroup :="" + if test.groupid!="" { + newgroup = test.groupid+badSymbols + } + encodedStream:=url.PathEscape(newstream) + encodedSource:=url.PathEscape(newsource) + encodedGroup:=url.PathEscape(newgroup) + test.reqString = strings.Replace(test.reqString,test.groupid,encodedGroup,1) + test.reqString = strings.Replace(test.reqString,test.source,encodedSource,1) + test.reqString = strings.Replace(test.reqString,test.stream,encodedStream,1) + dbname := expectedBeamtimeId + "_" + newsource + suite.mock_db.On("ProcessRequest", database.Request{DbName: dbname, DbCollectionName: newstream, GroupId: newgroup, Op: test.command, ExtraParam: test.externalParam}).Return([]byte("Hello"), nil) + logger.MockLog.On("Debug", mock.MatchedBy(containsMatcher("processing request "+test.command))) + w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + encodedSource + "/" + test.reqString+correctTokenSuffix+test.queryParams) + suite.Equal(http.StatusOK, w.Code, test.command+ " OK") + suite.Equal("Hello", string(w.Body.Bytes()), test.command+" sends data") + } +} + diff --git a/broker/src/asapo_broker/server/process_request.go b/broker/src/asapo_broker/server/process_request.go index 7e001f3397feef9dc0384da6d38aa21078caa5e3..23fe151a8865dc7debecc7c770a561fba8eb803b 100644 --- a/broker/src/asapo_broker/server/process_request.go +++ b/broker/src/asapo_broker/server/process_request.go @@ -8,20 +8,33 @@ import ( "asapo_common/version" "github.com/gorilla/mux" "net/http" + "net/url" ) +func readFromMapUnescaped(key string, vars map[string]string) (val string,ok bool) { + if val, ok = vars[key];!ok { + return + } + var err error + if val, err = url.PathUnescape(val);err!=nil { + return "",false + } + return +} + func extractRequestParameters(r *http.Request, needGroupID bool) (string, string, string, string, bool) { vars := mux.Vars(r) db_name, ok1 := vars["beamtime"] - - datasource, ok3 := vars["datasource"] - stream, ok4 := vars["stream"] + datasource, ok3 := readFromMapUnescaped("datasource",vars) + stream, ok4 := readFromMapUnescaped("stream",vars) ok2 := true group_id := "" if needGroupID { - group_id, ok2 = vars["groupid"] + group_id, ok2 = readFromMapUnescaped("groupid",vars) } + + return db_name, datasource, stream, group_id, ok1 && ok2 && ok3 && ok4 } @@ -34,22 +47,6 @@ func IsLetterOrNumbers(s string) bool { return true } - -func checkGroupID(w http.ResponseWriter, needGroupID bool, group_id string, db_name string, op string) bool { - if !needGroupID { - return true - } - if len(group_id) > 0 && len (group_id) < 100 && IsLetterOrNumbers(group_id) { - return true - } - err_str := "wrong groupid name, check length or allowed charecters in " + group_id - log_str := "processing get " + op + " request in " + db_name + " at " + settings.GetDatabaseServer() + ": " + err_str - logger.Error(log_str) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(err_str)) - return false -} - func checkBrokerApiVersion(w http.ResponseWriter, r *http.Request) bool { _, ok := utils.PrecheckApiVersion(w, r, version.GetBrokerApiVersion()) return ok @@ -77,10 +74,6 @@ func processRequest(w http.ResponseWriter, r *http.Request, op string, extra_par return } - if !checkGroupID(w, needGroupID, group_id, db_name, op) { - return - } - request := database.Request{} request.DbName = db_name+"_"+datasource request.Op = op diff --git a/broker/src/asapo_broker/server/process_request_test.go b/broker/src/asapo_broker/server/process_request_test.go index 97248769155270c0bd33a7b194b25c2721169576..5204c1c540f3ca2535f90a78ead3ebb0b72c0da9 100644 --- a/broker/src/asapo_broker/server/process_request_test.go +++ b/broker/src/asapo_broker/server/process_request_test.go @@ -205,12 +205,6 @@ func (suite *ProcessRequestTestSuite) TestProcessRequestAddsCounter() { suite.Equal(1, statistics.GetCounter(), "ProcessRequest increases counter") } -func (suite *ProcessRequestTestSuite) TestProcessRequestWrongGroupID() { - logger.MockLog.On("Error", mock.MatchedBy(containsMatcher("wrong groupid"))) - w := doRequest("/beamtime/" + expectedBeamtimeId + "/" + expectedSource + "/" + expectedStream + "/" + wrongGroupID + "/next" + correctTokenSuffix) - suite.Equal(http.StatusBadRequest, w.Code, "wrong group id") -} - func (suite *ProcessRequestTestSuite) TestProcessRequestAddsDataset() { expectedRequest := database.Request{DbName: expectedDBName, DbCollectionName: expectedStream, GroupId: expectedGroupID, DatasetOp: true, Op: "next"} diff --git a/common/cpp/include/asapo/database/database.h b/common/cpp/include/asapo/database/database.h index d14861ddefc987781fb0adf2db1a10ed026a5d01..fbe45cccb61c9069d47582463e0a2eda39c5af93 100644 --- a/common/cpp/include/asapo/database/database.h +++ b/common/cpp/include/asapo/database/database.h @@ -11,7 +11,6 @@ namespace asapo { constexpr char kDBDataCollectionNamePrefix[] = "data"; constexpr char kDBMetaCollectionName[] = "meta"; - class Database { public: virtual Error Connect(const std::string& address, const std::string& database) = 0; diff --git a/common/cpp/include/asapo/database/db_error.h b/common/cpp/include/asapo/database/db_error.h index b02c60a7129988fadbd6494826a66bd10a94ee6a..6128dcbbd76c50f28a7d8c1c5430bb35345b6734 100644 --- a/common/cpp/include/asapo/database/db_error.h +++ b/common/cpp/include/asapo/database/db_error.h @@ -16,7 +16,8 @@ enum class DBErrorType { kAlreadyConnected, kBadAddress, kMemoryError, - kNoRecord + kNoRecord, + kWrongInput }; using DBError = ServiceError<DBErrorType, ErrorType::kDBError>; @@ -28,6 +29,10 @@ auto const kNoRecord = DBErrorTemplate { "No record", DBErrorType::kNoRecord }; +auto const kWrongInput = DBErrorTemplate { + "Wrong input", DBErrorType::kWrongInput +}; + auto const kNotConnected = DBErrorTemplate { "Not connected", DBErrorType::kNotConnected diff --git a/common/cpp/include/asapo/http_client/http_client.h b/common/cpp/include/asapo/http_client/http_client.h index 3a41ea96b28013c655b5d51ce4abcee6c80977a4..5733c68f52140f375b1a19579e81e819f0fd4e1e 100644 --- a/common/cpp/include/asapo/http_client/http_client.h +++ b/common/cpp/include/asapo/http_client/http_client.h @@ -21,6 +21,7 @@ class HttpClient { virtual Error Post(const std::string& uri, const std::string& cookie, const std::string& input_data, std::string output_file_name, HttpCode* response_code) const noexcept = 0; + virtual std::string UrlEscape(const std::string& uri) const noexcept = 0; virtual ~HttpClient() = default; }; diff --git a/common/cpp/include/asapo/unittests/MockHttpClient.h b/common/cpp/include/asapo/unittests/MockHttpClient.h index 41a5b5d232e2146c7851928ba503792cb959e8f1..3b41fcba69519d8890a2ab69cd217a85dc5e68fb 100644 --- a/common/cpp/include/asapo/unittests/MockHttpClient.h +++ b/common/cpp/include/asapo/unittests/MockHttpClient.h @@ -37,6 +37,13 @@ class MockHttpClient : public HttpClient { }; + std::string UrlEscape(const std::string& uri) const noexcept override { + return UrlEscape_t(uri); + } + + MOCK_CONST_METHOD1(UrlEscape_t, std::string(const std::string& uri)); + + MOCK_CONST_METHOD3(Get_t, std::string(const std::string& uri, HttpCode* code, ErrorInterface** err)); MOCK_CONST_METHOD5(Post_t, diff --git a/common/cpp/src/database/CMakeLists.txt b/common/cpp/src/database/CMakeLists.txt index 261134f12893116d9577afb3db0d0a1eadac4bb5..eda63f850608b08872b79c147ee887f7852a473d 100644 --- a/common/cpp/src/database/CMakeLists.txt +++ b/common/cpp/src/database/CMakeLists.txt @@ -1,6 +1,7 @@ set(TARGET_NAME database) set(SOURCE_FILES mongodb_client.cpp + encoding.cpp database.cpp) ################################ @@ -16,3 +17,15 @@ target_include_directories(${TARGET_NAME} PUBLIC ${ASAPO_CXX_COMMON_INCLUDE_DIR} PUBLIC "${MONGOC_STATIC_INCLUDE_DIRS}") target_link_libraries (${TARGET_NAME} PRIVATE "${MONGOC_STATIC_LIBRARIES}") target_compile_definitions (${TARGET_NAME} PRIVATE "${MONGOC_STATIC_DEFINITIONS}") + + +################################ +# Testing +################################ +set(TEST_SOURCE_FILES ../../unittests/database/test_encoding.cpp) + +set(TEST_LIBRARIES "${TARGET_NAME}") + +include_directories(${ASAPO_CXX_COMMON_INCLUDE_DIR}) +gtest(${TARGET_NAME} "${TEST_SOURCE_FILES}" "${TEST_LIBRARIES}" "nocov") + diff --git a/common/cpp/src/database/encoding.cpp b/common/cpp/src/database/encoding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..533c3748ea7fcde43b29ba11068ed370b1e53412 --- /dev/null +++ b/common/cpp/src/database/encoding.cpp @@ -0,0 +1,137 @@ +#include "encoding.h" +#include <string.h> +#include <stdio.h> +#include <memory> + +namespace asapo { + +bool ShouldEscape(char c, bool db) { + if (c == '$' || c == ' ' || c == '%') { + return true; + } + if (!db) { + return false; + } + + switch (c) { + case '/': + case '\\': + case '.': + case '"': + return true; + } + return false; +} + +const std::string upperhex = "0123456789ABCDEF"; + +std::string Escape(const std::string& s, bool db) { + auto hexCount = 0; + for (auto i = 0; i < s.size(); i++) { + char c = s[i]; + if (ShouldEscape(c, db)) { + hexCount++; + } + } + + if (hexCount == 0) { + return s; + } + + std::string res; + res.reserve(s.size() + 2 * hexCount); + for (auto i = 0; i < s.size(); i++) { + auto c = s[i]; + if (ShouldEscape(c, db)) { + res.push_back('%'); + res.push_back(upperhex[c >> 4]); + res.push_back(upperhex[c & 15]); + } else { + res.push_back(c); + } + } + return res; +} + +inline int ishex(int x) { + return (x >= '0' && x <= '9') || + (x >= 'a' && x <= 'f') || + (x >= 'A' && x <= 'F'); +} + +int decode(const char* s, char* dec) { + char* o; + const char* end = s + strlen(s); + int c; + + for (o = dec; s <= end; o++) { + c = *s++; +// if (c == '+') c = ' '; + if (c == '%' && (!ishex(*s++) || + !ishex(*s++) || + !sscanf(s - 2, "%2x", &c))) + return -1; + if (dec) *o = c; + } + + return o - dec; +} + +std::string EncodeDbName(const std::string& dbname) { + return Escape(dbname, true); +} + +std::string EncodeColName(const std::string& colname) { + return Escape(colname, false); +} + +std::string DecodeName(const std::string& name) { + char* decoded = new char[name.size() + 1]; + auto res = decode(name.c_str(), decoded); + if (res < 0) { + return ""; + } + std::string str = std::string{decoded}; + delete[] decoded; + return str; +} + +bool ShouldEscapeQuery(char c) { + char chars[] = "-[]{}()*+?\\.,^$|#"; + for (auto i = 0; i < strlen(chars); i++) { + if (c == chars[i]) { + return true; + } + }; + return false; +} + +std::string EscapeQuery(const std::string& s) { + auto count = 0; + for (auto i = 0; i < s.size(); i++) { + char c = s[i]; + if (ShouldEscapeQuery(c)) { + count++; + } + } + + if (count == 0) { + return s; + } + + std::string res; + res.reserve(s.size() + count); + for (auto i = 0; i < s.size(); i++) { + auto c = s[i]; + if (ShouldEscapeQuery(c)) { + res.push_back('\\'); + res.push_back(c); + } else { + res.push_back(c); + } + } + return res; + +} + +} diff --git a/common/cpp/src/database/encoding.h b/common/cpp/src/database/encoding.h new file mode 100644 index 0000000000000000000000000000000000000000..a0b5f47f64eaad4fc9d1e1c853bba4f7693dd460 --- /dev/null +++ b/common/cpp/src/database/encoding.h @@ -0,0 +1,14 @@ +#ifndef ASAPO_ENCODING_H +#define ASAPO_ENCODING_H + +#include <string> + +namespace asapo { + +std::string EncodeDbName(const std::string& dbname); +std::string EncodeColName(const std::string& colname); +std::string DecodeName(const std::string& name); +std::string EscapeQuery(const std::string& name); +} + +#endif //ASAPO_ENCODING_H diff --git a/common/cpp/src/database/mongodb_client.cpp b/common/cpp/src/database/mongodb_client.cpp index e4072d91c6e704ca1e56cb68a917fd12a0f46bfb..432982976c4ec6eaae9888ab6a299e685fe1e349 100644 --- a/common/cpp/src/database/mongodb_client.cpp +++ b/common/cpp/src/database/mongodb_client.cpp @@ -1,7 +1,9 @@ #include "asapo/json_parser/json_parser.h" #include "mongodb_client.h" +#include "encoding.h" #include <chrono> +#include <regex> #include "asapo/database/db_error.h" #include "asapo/common/data_structs.h" @@ -53,18 +55,26 @@ Error MongoDBClient::InitializeClient(const std::string& address) { } -void MongoDBClient::UpdateCurrentCollectionIfNeeded(const std::string& collection_name) const { +Error MongoDBClient::UpdateCurrentCollectionIfNeeded(const std::string& collection_name) const { if (collection_name == current_collection_name_) { - return; + return nullptr; } if (current_collection_ != nullptr) { mongoc_collection_destroy(current_collection_); } + auto encoded_name = EncodeColName(collection_name); + if (encoded_name.size() > maxCollectionNameLength) { + return DBErrorTemplates::kWrongInput.Generate("stream name too long"); + } + current_collection_ = mongoc_client_get_collection(client_, database_name_.c_str(), - collection_name.c_str()); + encoded_name.c_str()); current_collection_name_ = collection_name; mongoc_collection_set_write_concern(current_collection_, write_concern_); + + return nullptr; + } Error MongoDBClient::TryConnectDatabase() { @@ -82,10 +92,16 @@ Error MongoDBClient::Connect(const std::string& address, const std::string& data auto err = InitializeClient(address); if (err) { + CleanUp(); return err; } - database_name_ = std::move(database_name); + database_name_ = EncodeDbName(database_name); + + if (database_name_.size() > maxDbNameLength) { + CleanUp(); + return DBErrorTemplates::kWrongInput.Generate("data source name too long"); + } err = TryConnectDatabase(); if (err) { @@ -101,12 +117,15 @@ std::string MongoDBClient::DBAddress(const std::string& address) const { void MongoDBClient::CleanUp() { if (write_concern_) { mongoc_write_concern_destroy(write_concern_); + write_concern_ = nullptr; } if (current_collection_) { mongoc_collection_destroy(current_collection_); + current_collection_ = nullptr; } if (client_) { mongoc_client_destroy(client_); + client_ = nullptr; } } @@ -177,9 +196,11 @@ Error MongoDBClient::Insert(const std::string& collection, const MessageMeta& fi return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; auto document = PrepareBsonDocument(file, &err); if (err) { return err; @@ -189,9 +210,6 @@ Error MongoDBClient::Insert(const std::string& collection, const MessageMeta& fi } MongoDBClient::~MongoDBClient() { - if (!connected_) { - return; - } CleanUp(); } @@ -200,9 +218,11 @@ Error MongoDBClient::Upsert(const std::string& collection, uint64_t id, const ui return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; auto document = PrepareBsonDocument(data, (ssize_t) size, &err); if (err) { return err; @@ -244,9 +264,11 @@ Error MongoDBClient::InsertAsDatasetMessage(const std::string& collection, const return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; auto document = PrepareBsonDocument(file, &err); if (err) { return err; @@ -275,9 +297,11 @@ Error MongoDBClient::GetRecordFromDb(const std::string& collection, uint64_t id, return DBErrorTemplates::kNotConnected.Generate(); } - UpdateCurrentCollectionIfNeeded(collection); + auto err = UpdateCurrentCollectionIfNeeded(collection); + if (err) { + return err; + } - Error err; bson_error_t mongo_err; bson_t* filter; bson_t* opts; @@ -467,7 +491,12 @@ bool MongoCollectionIsDataStream(const std::string& stream_name) { return stream_name.rfind(prefix, 0) == 0; } -Error MongoDBClient::UpdateCurrentLastStreamInfo(const std::string& collection_name, StreamInfo* info) const { + +Error MongoDBClient::UpdateLastStreamInfo(const char* str, StreamInfo* info) const { + auto collection_name = DecodeName(str); + if (!MongoCollectionIsDataStream(collection_name)) { + return nullptr; + } StreamInfo next_info; auto err = GetStreamInfo(collection_name, &next_info); std::string prefix = std::string(kDBDataCollectionNamePrefix) + "_"; @@ -481,17 +510,6 @@ Error MongoDBClient::UpdateCurrentLastStreamInfo(const std::string& collection_n return nullptr; } -Error MongoDBClient::UpdateLastStreamInfo(const char* str, StreamInfo* info) const { - std::string collection_name{str}; - if (MongoCollectionIsDataStream(collection_name)) { - auto err = UpdateCurrentLastStreamInfo(collection_name, info); - if (err) { - return err; - } - } - return nullptr; -} - Error MongoDBClient::GetLastStream(StreamInfo* info) const { if (!connected_) { return DBErrorTemplates::kNotConnected.Generate(); @@ -521,6 +539,10 @@ Error MongoDBClient::GetLastStream(StreamInfo* info) const { err = DBErrorTemplates::kDBError.Generate(error.message); } + if (err != nullptr) { + info->name = DecodeName(info->name); + } + bson_destroy(opts); mongoc_database_destroy(database); return err; @@ -530,7 +552,7 @@ Error MongoDBClient::DeleteCollections(const std::string& prefix) const { mongoc_database_t* database; char** strv; bson_error_t error; - std::string querystr = "^" + prefix; + std::string querystr = "^" + EscapeQuery(prefix); bson_t* query = BCON_NEW ("name", BCON_REGEX(querystr.c_str(), "i")); bson_t* opts = BCON_NEW ("nameOnly", BCON_BOOL(true), "filter", BCON_DOCUMENT(query)); database = mongoc_client_get_database(client_, database_name_.c_str()); @@ -559,7 +581,7 @@ Error MongoDBClient::DeleteCollection(const std::string& name) const { mongoc_collection_destroy(collection); if (!r) { if (error.code == 26) { - return DBErrorTemplates::kNoRecord.Generate("collection " + name + " not found in " + database_name_); + return DBErrorTemplates::kNoRecord.Generate("collection " + name + " not found in " + DecodeName(database_name_)); } else { return DBErrorTemplates::kDBError.Generate(std::string(error.message) + ": " + std::to_string(error.code)); } @@ -582,15 +604,16 @@ Error MongoDBClient::DeleteDocumentsInCollection(const std::string& collection_n } Error MongoDBClient::DeleteStream(const std::string& stream) const { - std::string data_col = std::string(kDBDataCollectionNamePrefix) + "_" + stream; - std::string inprocess_col = "inprocess_" + stream; - std::string acks_col = "acks_" + stream; + auto stream_encoded = EncodeColName(stream); + std::string data_col = std::string(kDBDataCollectionNamePrefix) + "_" + stream_encoded; + std::string inprocess_col = "inprocess_" + stream_encoded; + std::string acks_col = "acks_" + stream_encoded; current_collection_name_ = ""; auto err = DeleteCollection(data_col); if (err == nullptr) { DeleteCollections(inprocess_col); DeleteCollections(acks_col); - std::string querystr = ".*_" + stream + "$"; + std::string querystr = ".*_" + EscapeQuery(stream_encoded) + "$"; DeleteDocumentsInCollection("current_location", querystr); } return err; diff --git a/common/cpp/src/database/mongodb_client.h b/common/cpp/src/database/mongodb_client.h index 443d73bd68dbc2184cf72272da63b1eca4413524..3999671fe73f18c5a26708a29a80834fb83f77e5 100644 --- a/common/cpp/src/database/mongodb_client.h +++ b/common/cpp/src/database/mongodb_client.h @@ -39,6 +39,9 @@ enum class GetRecordMode { kEarliest }; +const size_t maxDbNameLength = 63; +const size_t maxCollectionNameLength = 100; + class MongoDBClient final : public Database { public: MongoDBClient(); @@ -63,7 +66,7 @@ class MongoDBClient final : public Database { void CleanUp(); std::string DBAddress(const std::string& address) const; Error InitializeClient(const std::string& address); - void UpdateCurrentCollectionIfNeeded(const std::string& collection_name) const ; + Error UpdateCurrentCollectionIfNeeded(const std::string& collection_name) const ; Error Ping(); Error TryConnectDatabase(); Error InsertBsonDocument(const bson_p& document, bool ignore_duplicates) const; @@ -71,7 +74,6 @@ class MongoDBClient final : public Database { Error AddBsonDocumentToArray(bson_t* query, bson_t* update, bool ignore_duplicates) const; Error GetRecordFromDb(const std::string& collection, uint64_t id, GetRecordMode mode, std::string* res) const; Error UpdateLastStreamInfo(const char* str, StreamInfo* info) const; - Error UpdateCurrentLastStreamInfo(const std::string& collection_name, StreamInfo* info) const; Error DeleteCollection(const std::string& name) const; Error DeleteCollections(const std::string& prefix) const; Error DeleteDocumentsInCollection(const std::string& collection_name, const std::string& querystr) const; diff --git a/common/cpp/src/http_client/curl_http_client.cpp b/common/cpp/src/http_client/curl_http_client.cpp index 0bdefda2189197185f2c5e7797dc1b0f1a491ffa..0eedf7d1dd78ad36be3d50b360e39426eb9fd47a 100644 --- a/common/cpp/src/http_client/curl_http_client.cpp +++ b/common/cpp/src/http_client/curl_http_client.cpp @@ -218,5 +218,18 @@ CurlHttpClient::~CurlHttpClient() { curl_easy_cleanup(curl_); } } +std::string CurlHttpClient::UrlEscape(const std::string& uri) const noexcept { + if (!curl_) { + return ""; + } + char* output = curl_easy_escape(curl_, uri.c_str(), uri.size()); + if (output) { + auto res = std::string(output); + curl_free(output); + return res; + } else { + return ""; + } +} } diff --git a/common/cpp/src/http_client/curl_http_client.h b/common/cpp/src/http_client/curl_http_client.h index cfc2e7626974422a72d87b371660cd7052cc34ea..50ffc1cef90df7c511ecc37f1352bab3287baead 100644 --- a/common/cpp/src/http_client/curl_http_client.h +++ b/common/cpp/src/http_client/curl_http_client.h @@ -34,6 +34,8 @@ class CurlHttpClient final : public HttpClient { std::string Get(const std::string& uri, HttpCode* response_code, Error* err) const noexcept override; std::string Post(const std::string& uri, const std::string& cookie, const std::string& data, HttpCode* response_code, Error* err) const noexcept override; + std::string UrlEscape(const std::string& uri) const noexcept override; + Error Post(const std::string& uri, const std::string& cookie, const std::string& input_data, MessageData* output_data, uint64_t output_data_size, HttpCode* response_code) const noexcept override; diff --git a/common/cpp/unittests/database/test_encoding.cpp b/common/cpp/unittests/database/test_encoding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8cd7ad3f1967f174f95f9731a9d83381a4684f8 --- /dev/null +++ b/common/cpp/unittests/database/test_encoding.cpp @@ -0,0 +1,50 @@ +#include "../../src/database/encoding.h" +#include "asapo/preprocessor/definitions.h" +#include <gmock/gmock.h> +#include "gtest/gtest.h" +#include <chrono> + +using ::testing::AtLeast; +using ::testing::Eq; +using ::testing::Ne; +using ::testing::Test; +using ::testing::_; +using ::testing::Mock; +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::SetArgPointee; + + +namespace { + +uint64_t big_uint = 18446744073709551615ull; + +TEST(EncodingTests, EncodeDbName) { + + std::string dbname = R"(db_/\."$)"; + std::string dbname_encoded = "db_%2F%5C%2E%22%24"; + + auto encoded = asapo::EncodeDbName(dbname); + auto decoded = asapo::DecodeName(encoded); + + ASSERT_THAT(encoded, Eq(dbname_encoded)); + ASSERT_THAT(decoded, Eq(dbname)); +} + +TEST(EncodingTests, EncodeColName) { + + std::string colname = R"(col_/\."$)"; + std::string colname_encoded = R"(col_/\."%24)"; + + auto encoded = asapo::EncodeColName(colname); + + auto decoded = asapo::DecodeName(encoded); + + + ASSERT_THAT(encoded, Eq(colname_encoded)); + ASSERT_THAT(decoded, Eq(colname)); +} + + + +} diff --git a/common/go/src/asapo_common/utils/routes.go b/common/go/src/asapo_common/utils/routes.go index 598c443d92ce204243904ccd8bc9201b7ad94014..83b78f4180b6ac6e31137e3e7fcbb9c0f6f1845b 100644 --- a/common/go/src/asapo_common/utils/routes.go +++ b/common/go/src/asapo_common/utils/routes.go @@ -16,7 +16,7 @@ type Route struct { } func NewRouter(listRoutes Routes) *mux.Router { - router := mux.NewRouter() + router := mux.NewRouter().UseEncodedPath() for _, route := range listRoutes { router. Methods(route.Method). diff --git a/consumer/api/cpp/src/consumer_impl.cpp b/consumer/api/cpp/src/consumer_impl.cpp index 0efcc6cb3e6d25c224cf182d3221a3f38b2ac6ad..00ca9b96158f167159ff5817927a059a2f862fec 100644 --- a/consumer/api/cpp/src/consumer_impl.cpp +++ b/consumer/api/cpp/src/consumer_impl.cpp @@ -141,7 +141,7 @@ ConsumerImpl::ConsumerImpl(std::string server_uri, if (source_credentials_.data_source.empty()) { source_credentials_.data_source = SourceCredentials::kDefaultDataSource; } - + data_source_encoded_ = httpclient__->UrlEscape(source_credentials_.data_source); } void ConsumerImpl::SetTimeout(uint64_t timeout_ms) { @@ -291,10 +291,8 @@ Error ConsumerImpl::GetRecordFromServer(std::string* response, std::string group interrupt_flag_ = false; std::string request_suffix = OpToUriCmd(op); std::string request_group = OpToUriCmd(op); - std::string - request_api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source - + "/" + std::move(stream); + + std::string request_api = UriPrefix(std::move(stream), "", ""); uint64_t elapsed_ms = 0; Error no_data_error; while (true) { @@ -305,7 +303,8 @@ Error ConsumerImpl::GetRecordFromServer(std::string* response, std::string group auto start = system_clock::now(); auto err = DiscoverService(kBrokerServiceName, ¤t_broker_uri_); if (err == nullptr) { - auto ri = PrepareRequestInfo(request_api + "/" + group_id + "/" + request_suffix, dataset, min_size); + auto ri = PrepareRequestInfo(request_api + "/" + httpclient__->UrlEscape(group_id) + "/" + request_suffix, dataset, + min_size); if (request_suffix == "next" && resend_) { ri.extra_params = ri.extra_params + "&resend_nacks=true" + "&delay_ms=" + std::to_string(delay_ms_) + "&resend_attempts=" + std::to_string(resend_attempts_); @@ -577,9 +576,8 @@ Error ConsumerImpl::ResetLastReadMarker(std::string group_id, std::string stream Error ConsumerImpl::SetLastReadMarker(std::string group_id, uint64_t value, std::string stream) { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + "/" - + std::move(stream) + "/" + std::move(group_id) + "/resetcounter"; + ri.api = UriPrefix(std::move(stream), std::move(group_id), "resetcounter"); + ri.extra_params = "&value=" + std::to_string(value); ri.post = true; @@ -608,11 +606,9 @@ Error ConsumerImpl::GetRecordFromServerById(uint64_t id, std::string* response, } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move( - group_id) + "/" + std::to_string(id); + ri.api = UriPrefix(std::move(stream), std::move(group_id), std::to_string(id)); + + if (dataset) { ri.extra_params += "&dataset=true"; ri.extra_params += "&minsize=" + std::to_string(min_size); @@ -625,9 +621,7 @@ Error ConsumerImpl::GetRecordFromServerById(uint64_t id, std::string* response, std::string ConsumerImpl::GetBeamtimeMeta(Error* err) { RequestInfo ri; - ri.api = - "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + "/default/0/meta/0"; + ri.api = UriPrefix("default", "0", "meta/0"); return BrokerRequestWithTimeout(ri, err); } @@ -649,9 +643,8 @@ MessageMetas ConsumerImpl::QueryMessages(std::string query, std::string stream, } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - "/" + std::move(stream) + "/0/querymessages"; + ri.api = UriPrefix(std::move(stream), "0", "querymessages"); + ri.post = true; ri.body = std::move(query); @@ -748,11 +741,10 @@ StreamInfos ConsumerImpl::GetStreamList(std::string from, StreamFilter filter, E RequestInfo ConsumerImpl::GetStreamListRequest(const std::string& from, const StreamFilter& filter) const { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + "/0/streams"; + ri.api = UriPrefix("0", "", "streams"); ri.post = false; if (!from.empty()) { - ri.extra_params = "&from=" + from; + ri.extra_params = "&from=" + httpclient__->UrlEscape(from); } ri.extra_params += "&filter=" + filterToString(filter); return ri; @@ -817,10 +809,7 @@ Error ConsumerImpl::Acknowledge(std::string group_id, uint64_t id, std::string s return ConsumerErrorTemplates::kWrongInput.Generate("empty stream"); } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/" + std::to_string(id); + ri.api = UriPrefix(std::move(stream), std::move(group_id), std::to_string(id)); ri.post = true; ri.body = "{\"Op\":\"ackmessage\"}"; @@ -839,10 +828,7 @@ IdList ConsumerImpl::GetUnacknowledgedMessages(std::string group_id, return {}; } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/nacks"; + ri.api = UriPrefix(std::move(stream), std::move(group_id), "nacks"); ri.extra_params = "&from=" + std::to_string(from_id) + "&to=" + std::to_string(to_id); auto json_string = BrokerRequestWithTimeout(ri, error); @@ -865,10 +851,7 @@ uint64_t ConsumerImpl::GetLastAcknowledgedMessage(std::string group_id, std::str return 0; } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/lastack"; + ri.api = UriPrefix(std::move(stream), std::move(group_id), "lastack"); auto json_string = BrokerRequestWithTimeout(ri, error); if (*error) { @@ -901,10 +884,7 @@ Error ConsumerImpl::NegativeAcknowledge(std::string group_id, return ConsumerErrorTemplates::kWrongInput.Generate("empty stream"); } RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/" + std::move(group_id) + "/" + std::to_string(id); + ri.api = UriPrefix(std::move(stream), std::move(group_id), std::to_string(id)); ri.post = true; ri.body = R"({"Op":"negackmessage","Params":{"DelayMs":)" + std::to_string(delay_ms) + "}}"; @@ -946,9 +926,7 @@ uint64_t ConsumerImpl::ParseGetCurrentCountResponce(Error* err, const std::strin RequestInfo ConsumerImpl::GetSizeRequestForSingleMessagesStream(std::string& stream) const { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + "/size"; + ri.api = UriPrefix(std::move(stream), "", "size"); return ri; } @@ -988,10 +966,7 @@ Error ConsumerImpl::GetVersionInfo(std::string* client_info, std::string* server RequestInfo ConsumerImpl::GetDeleteStreamRequest(std::string stream, DeleteStreamOptions options) const { RequestInfo ri; - ri.api = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" - + source_credentials_.data_source + - +"/" + std::move(stream) + - "/delete"; + ri.api = UriPrefix(std::move(stream), "", "delete"); ri.post = true; ri.body = options.Json(); return ri; @@ -1004,4 +979,20 @@ Error ConsumerImpl::DeleteStream(std::string stream, DeleteStreamOptions options return err; } +std::string ConsumerImpl::UriPrefix( std::string stream, std::string group, std::string suffix) const { + auto stream_encoded = httpclient__->UrlEscape(std::move(stream)); + auto group_encoded = group.size() > 0 ? httpclient__->UrlEscape(std::move(group)) : ""; + auto uri = "/" + kConsumerProtocol.GetBrokerVersion() + "/beamtime/" + source_credentials_.beamtime_id + "/" + + data_source_encoded_ + "/" + stream_encoded; + if (group_encoded.size() > 0) { + uri = uri + "/" + group_encoded; + } + if (suffix.size() > 0) { + uri = uri + "/" + suffix; + } + + return uri; + +} + } \ No newline at end of file diff --git a/consumer/api/cpp/src/consumer_impl.h b/consumer/api/cpp/src/consumer_impl.h index dd2a2635fbc26b776ab66cb6070a96c1da9abe3c..2804309ae2446656e4c49224b8f917ba695c526a 100644 --- a/consumer/api/cpp/src/consumer_impl.h +++ b/consumer/api/cpp/src/consumer_impl.h @@ -150,6 +150,7 @@ class ConsumerImpl final : public asapo::Consumer { uint64_t GetCurrentCount(std::string stream, const RequestInfo& ri, Error* err); RequestInfo GetStreamListRequest(const std::string& from, const StreamFilter& filter) const; Error GetServerVersionInfo(std::string* server_info, bool* supported) ; + std::string UriPrefix( std::string stream, std::string group, std::string suffix) const; std::string endpoint_; std::string current_broker_uri_; @@ -157,6 +158,7 @@ class ConsumerImpl final : public asapo::Consumer { std::string source_path_; bool has_filesystem_; SourceCredentials source_credentials_; + std::string data_source_encoded_; uint64_t timeout_ms_ = 0; bool should_try_rdma_first_ = true; NetworkConnectionType current_connection_type_ = NetworkConnectionType::kUndefined; diff --git a/consumer/api/cpp/unittests/test_consumer_impl.cpp b/consumer/api/cpp/unittests/test_consumer_impl.cpp index aca44d359c39a3cb3dbc8ad5632cbca03cef8921..862f2f64e3c77921ca8680bf24c55838f9e4c773 100644 --- a/consumer/api/cpp/unittests/test_consumer_impl.cpp +++ b/consumer/api/cpp/unittests/test_consumer_impl.cpp @@ -79,9 +79,13 @@ class ConsumerImplTests : public Test { std::string expected_path = "/tmp/beamline/beamtime"; std::string expected_filename = "filename"; std::string expected_full_path = std::string("/tmp/beamline/beamtime") + asapo::kPathSeparator + expected_filename; - std::string expected_group_id = "groupid"; - std::string expected_data_source = "source"; - std::string expected_stream = "stream"; + std::string expected_group_id = "groupid$"; + std::string expected_data_source = "source/$.?"; + std::string expected_stream = "str $ eam$"; + std::string expected_group_id_encoded = "groupid%24"; + std::string expected_data_source_encoded = "source%2F%24.%3F"; + std::string expected_stream_encoded = "str%20%24%20eam%24"; + std::string expected_metadata = "{\"meta\":1}"; std::string expected_query_string = "bla"; std::string expected_folder_token = "folder_token"; @@ -118,6 +122,13 @@ class ConsumerImplTests : public Test { fts_consumer->io__ = std::unique_ptr<IO> {&mock_io}; fts_consumer->httpclient__ = std::unique_ptr<asapo::HttpClient> {&mock_http_client}; fts_consumer->net_client__ = std::unique_ptr<asapo::NetClient> {&mock_netclient}; + ON_CALL(mock_http_client, UrlEscape_t(expected_stream)).WillByDefault(Return(expected_stream_encoded)); + ON_CALL(mock_http_client, UrlEscape_t(expected_group_id)).WillByDefault(Return(expected_group_id_encoded)); + ON_CALL(mock_http_client, UrlEscape_t(expected_data_source)).WillByDefault(Return(expected_data_source_encoded)); + ON_CALL(mock_http_client, UrlEscape_t("0")).WillByDefault(Return("0")); + ON_CALL(mock_http_client, UrlEscape_t("")).WillByDefault(Return("")); + ON_CALL(mock_http_client, UrlEscape_t("default")).WillByDefault(Return("default")); + ON_CALL(mock_http_client, UrlEscape_t("stream")).WillByDefault(Return("stream")); } void TearDown() override { @@ -217,7 +228,7 @@ TEST_F(ConsumerImplTests, DefaultStreamIsDetector) { MockGetBrokerUri(); EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/detector/stream/" + expected_group_id + Get_t(expected_broker_api + "/beamtime/beamtime_id/detector/stream/" + expected_group_id_encoded + "/next?token=" + expected_token, _, @@ -226,14 +237,15 @@ TEST_F(ConsumerImplTests, DefaultStreamIsDetector) { SetArgPointee<2>(nullptr), Return(""))); - consumer->GetNext(expected_group_id, &info, nullptr, expected_stream); + consumer->GetNext(expected_group_id, &info, nullptr, "stream"); } TEST_F(ConsumerImplTests, GetNextUsesCorrectUriWithStream) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + expected_group_id + "/next?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/" + expected_group_id_encoded + "/next?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -246,7 +258,8 @@ TEST_F(ConsumerImplTests, GetLastUsesCorrectUri) { MockGetBrokerUri(); EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + expected_stream + "/0/last?token=" + Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + expected_stream_encoded + + "/0/last?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -429,7 +442,8 @@ TEST_F(ConsumerImplTests, GetMessageReturnsNoDataAfterTimeoutEvenIfOtherErrorOcc Return("{\"op\":\"get_record_by_id\",\"id\":" + std::to_string(expected_dataset_id) + ",\"id_max\":2,\"next_stream\":\"""\"}"))); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).Times(AtLeast(1)).WillRepeatedly(DoAll( SetArgPointee<1>(HttpCode::NotFound), @@ -437,7 +451,7 @@ TEST_F(ConsumerImplTests, GetMessageReturnsNoDataAfterTimeoutEvenIfOtherErrorOcc Return(""))); consumer->SetTimeout(300); - auto err = consumer->GetNext(expected_group_id, &info, nullptr, expected_stream); + auto err = consumer->GetNext(expected_group_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kNoData)); } @@ -630,13 +644,13 @@ TEST_F(ConsumerImplTests, ResetCounterByDefaultUsesCorrectUri) { consumer->SetTimeout(100); EXPECT_CALL(mock_http_client, - Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/" + - expected_group_id + + Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/" + + expected_group_id_encoded + "/resetcounter?token=" + expected_token + "&value=0", _, _, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), SetArgPointee<4>(nullptr), Return(""))); - auto err = consumer->ResetLastReadMarker(expected_group_id, expected_stream); + auto err = consumer->ResetLastReadMarker(expected_group_id, "stream"); ASSERT_THAT(err, Eq(nullptr)); } @@ -644,9 +658,10 @@ TEST_F(ConsumerImplTests, ResetCounterUsesCorrectUri) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/resetcounter?token=" + expected_token + "&value=10", _, _, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), SetArgPointee<4>(nullptr), @@ -659,8 +674,9 @@ TEST_F(ConsumerImplTests, GetCurrentSizeUsesCorrectUri) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/size?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/size?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), @@ -675,8 +691,8 @@ TEST_F(ConsumerImplTests, GetCurrentSizeErrorOnWrongResponce) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + - "/" + expected_stream + "/size?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/" + expected_stream_encoded + "/size?token=" + expected_token, _, _)).WillRepeatedly(DoAll( SetArgPointee<1>(HttpCode::Unauthorized), SetArgPointee<2>(nullptr), @@ -691,14 +707,14 @@ TEST_F(ConsumerImplTests, GetNDataErrorOnWrongParse) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/size?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return("{\"siz\":10}"))); asapo::Error err; - auto size = consumer->GetCurrentSize(expected_stream, &err); + auto size = consumer->GetCurrentSize("stream", &err); ASSERT_THAT(err, Ne(nullptr)); ASSERT_THAT(size, Eq(0)); } @@ -709,7 +725,8 @@ TEST_F(ConsumerImplTests, GetByIdUsesCorrectUri) { auto to_send = CreateFI(); auto json = to_send.Json(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/0/" + std::to_string( expected_dataset_id) + "?token=" + expected_token, _, @@ -718,7 +735,7 @@ TEST_F(ConsumerImplTests, GetByIdUsesCorrectUri) { SetArgPointee<2>(nullptr), Return(json))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(info.name, Eq(to_send.name)); @@ -728,14 +745,15 @@ TEST_F(ConsumerImplTests, GetByIdTimeouts) { MockGetBrokerUri(); consumer->SetTimeout(10); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::Conflict), SetArgPointee<2>(nullptr), Return(""))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kNoData)); } @@ -744,14 +762,15 @@ TEST_F(ConsumerImplTests, GetByIdReturnsEndOfStream) { MockGetBrokerUri(); consumer->SetTimeout(10); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::Conflict), SetArgPointee<2>(nullptr), Return("{\"op\":\"get_record_by_id\",\"id\":1,\"id_max\":1,\"next_stream\":\"""\"}"))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kEndOfStream)); } @@ -760,14 +779,15 @@ TEST_F(ConsumerImplTests, GetByIdReturnsEndOfStreamWhenIdTooLarge) { MockGetBrokerUri(); consumer->SetTimeout(10); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::Conflict), SetArgPointee<2>(nullptr), Return("{\"op\":\"get_record_by_id\",\"id\":100,\"id_max\":1,\"next_stream\":\"""\"}"))); - auto err = consumer->GetById(expected_dataset_id, &info, nullptr, expected_stream); + auto err = consumer->GetById(expected_dataset_id, &info, nullptr, "stream"); ASSERT_THAT(err, Eq(asapo::ConsumerErrorTemplates::kEndOfStream)); } @@ -776,7 +796,7 @@ TEST_F(ConsumerImplTests, GetMetaDataOK) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/default/0/meta/0?token=" + expected_token, _, _)).WillOnce(DoAll( @@ -882,7 +902,7 @@ TEST_F(ConsumerImplTests, QueryMessagesReturnRecords) { auto responce_string = "[" + json1 + "," + json2 + "]"; EXPECT_CALL(mock_http_client, - Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0" + + Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/stream/0" + "/querymessages?token=" + expected_token, _, expected_query_string, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), SetArgPointee<4>(nullptr), @@ -890,7 +910,7 @@ TEST_F(ConsumerImplTests, QueryMessagesReturnRecords) { consumer->SetTimeout(100); asapo::Error err; - auto messages = consumer->QueryMessages(expected_query_string, expected_stream, &err); + auto messages = consumer->QueryMessages(expected_query_string, "stream", &err); ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(messages.size(), Eq(2)); @@ -902,15 +922,16 @@ TEST_F(ConsumerImplTests, QueryMessagesReturnRecords) { TEST_F(ConsumerImplTests, GetNextDatasetUsesCorrectUri) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/" + - expected_group_id + "/next?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/" + + expected_group_id_encoded + "/next?token=" + expected_token + "&dataset=true&minsize=0", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(""))); asapo::Error err; - consumer->GetNextDataset(expected_group_id, 0, expected_stream, &err); + consumer->GetNextDataset(expected_group_id, 0, "stream", &err); } TEST_F(ConsumerImplTests, GetNextErrorOnEmptyStream) { @@ -1037,8 +1058,9 @@ TEST_F(ConsumerImplTests, GetDataSetReturnsParseError) { TEST_F(ConsumerImplTests, GetLastDatasetUsesCorrectUri) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/0/last?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/0/last?token=" + expected_token + "&dataset=true&minsize=1", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1051,7 +1073,8 @@ TEST_F(ConsumerImplTests, GetLastDatasetUsesCorrectUri) { TEST_F(ConsumerImplTests, GetDatasetByIdUsesCorrectUri) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/0/" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/0/" + std::to_string(expected_dataset_id) + "?token=" + expected_token + "&dataset=true" + "&minsize=0", _, _)).WillOnce(DoAll( @@ -1059,14 +1082,14 @@ TEST_F(ConsumerImplTests, GetDatasetByIdUsesCorrectUri) { SetArgPointee<2>(nullptr), Return(""))); asapo::Error err; - consumer->GetDatasetById(expected_dataset_id, 0, expected_stream, &err); + consumer->GetDatasetById(expected_dataset_id, 0, "stream", &err); } TEST_F(ConsumerImplTests, DeleteStreamUsesCorrectUri) { MockGetBrokerUri(); std::string expected_delete_stream_query_string = "{\"ErrorOnNotExist\":true,\"DeleteMeta\":true}"; - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/delete" + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + expected_stream_encoded + "/delete" + "?token=" + expected_token, _, expected_delete_stream_query_string, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), @@ -1090,15 +1113,15 @@ TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUri) { + R"({"lastId":124,"name":"test1","timestampCreated":2000000,"timestampLast":2000,"finished":true,"nextStream":"next"}]})"; EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/0/streams" - + "?token=" + expected_token + "&from=stream_from&filter=all", _, + Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/0/streams" + + "?token=" + expected_token + "&from=" + expected_stream_encoded + "&filter=all", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(return_streams))); asapo::Error err; - auto streams = consumer->GetStreamList("stream_from", asapo::StreamFilter::kAllStreams, &err); + auto streams = consumer->GetStreamList(expected_stream, asapo::StreamFilter::kAllStreams, &err); ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(streams.size(), Eq(2)); ASSERT_THAT(streams.size(), 2); @@ -1111,7 +1134,7 @@ TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUri) { TEST_F(ConsumerImplTests, GetStreamListUsesCorrectUriWithoutFrom) { MockGetBrokerUri(); EXPECT_CALL(mock_http_client, - Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/0/streams" + Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/0/streams" + "?token=" + expected_token + "&filter=finished", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1253,9 +1276,10 @@ TEST_F(ConsumerImplTests, GetMessageTriesToGetTokenAgainIfTransferFailed) { TEST_F(ConsumerImplTests, AcknowledgeUsesCorrectUri) { MockGetBrokerUri(); auto expected_acknowledge_command = "{\"Op\":\"ackmessage\"}"; - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, expected_acknowledge_command, _, _)).WillOnce(DoAll( SetArgPointee<3>(HttpCode::OK), @@ -1269,9 +1293,10 @@ TEST_F(ConsumerImplTests, AcknowledgeUsesCorrectUri) { void ConsumerImplTests::ExpectIdList(bool error) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + "/nacks?token=" + expected_token + "&from=1&to=0", _, _)).WillOnce(DoAll( + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/nacks?token=" + expected_token + "&from=1&to=0", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(error ? "" : "{\"unacknowledged\":[1,2,3]}"))); @@ -1287,9 +1312,10 @@ TEST_F(ConsumerImplTests, GetUnAcknowledgedListReturnsIds) { } void ConsumerImplTests::ExpectLastAckId(bool empty_response) { - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + "/lastack?token=" + expected_token, _, _)).WillOnce(DoAll( + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/lastack?token=" + expected_token, _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), Return(empty_response ? "{\"lastAckId\":0}" : "{\"lastAckId\":1}"))); @@ -1325,8 +1351,9 @@ TEST_F(ConsumerImplTests, GetByIdErrorsForId0) { TEST_F(ConsumerImplTests, ResendNacks) { MockGetBrokerUri(); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/stream/" - + expected_group_id + "/next?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + + "/stream/" + + expected_group_id_encoded + "/next?token=" + expected_token + "&resend_nacks=true&delay_ms=10000&resend_attempts=3", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), @@ -1334,15 +1361,16 @@ TEST_F(ConsumerImplTests, ResendNacks) { Return(""))); consumer->SetResendNacs(true, 10000, 3); - consumer->GetNext(expected_group_id, &info, nullptr, expected_stream); + consumer->GetNext(expected_group_id, &info, nullptr, "stream"); } TEST_F(ConsumerImplTests, NegativeAcknowledgeUsesCorrectUri) { MockGetBrokerUri(); auto expected_neg_acknowledge_command = R"({"Op":"negackmessage","Params":{"DelayMs":10000}})"; - EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/" + - expected_group_id + EXPECT_CALL(mock_http_client, Post_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/" + + expected_group_id_encoded + "/" + std::to_string(expected_dataset_id) + "?token=" + expected_token, _, expected_neg_acknowledge_command, _, _)).WillOnce( DoAll( @@ -1386,8 +1414,9 @@ TEST_F(ConsumerImplTests, GetCurrentDataSetCounteUsesCorrectUri) { MockGetBrokerUri(); consumer->SetTimeout(100); - EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source + "/" + - expected_stream + "/size?token=" + EXPECT_CALL(mock_http_client, Get_t(expected_broker_api + "/beamtime/beamtime_id/" + expected_data_source_encoded + "/" + + + expected_stream_encoded + "/size?token=" + expected_token + "&incomplete=true", _, _)).WillOnce(DoAll( SetArgPointee<1>(HttpCode::OK), SetArgPointee<2>(nullptr), diff --git a/discovery/src/asapo_discovery/protocols/hard_coded_consumer.go b/discovery/src/asapo_discovery/protocols/hard_coded_consumer.go index 685bbd4ed391022341ea262edab0a0ddfcd19abb..1410b30a9c767aef2c441de9cb8e7bb76ef76e26 100644 --- a/discovery/src/asapo_discovery/protocols/hard_coded_consumer.go +++ b/discovery/src/asapo_discovery/protocols/hard_coded_consumer.go @@ -27,7 +27,7 @@ func GetSupportedConsumerProtocols() []Protocol { "Broker": "v0.3", "File Transfer": "v0.1", "Data cache service": "v0.1", - }, &protocolValidatorDeprecated{getTimefromDate("2022-06-01")}}, + }, &protocolValidatorDeprecated{getTimefromDate("2022-07-01")}}, Protocol{"v0.2", map[string]string{ "Discovery": "v0.1", diff --git a/discovery/src/asapo_discovery/protocols/hard_coded_producer.go b/discovery/src/asapo_discovery/protocols/hard_coded_producer.go index a361478ccf442da1badfae0bc56c87b61164fc0f..7207d0547d80767119b324b53e1ad2dd33d82d16 100644 --- a/discovery/src/asapo_discovery/protocols/hard_coded_producer.go +++ b/discovery/src/asapo_discovery/protocols/hard_coded_producer.go @@ -2,11 +2,16 @@ package protocols func GetSupportedProducerProtocols() []Protocol { return []Protocol{ + Protocol{"v0.3", + map[string]string{ + "Discovery": "v0.1", + "Receiver": "v0.3", + }, &protocolValidatorCurrent{}}, Protocol{"v0.2", map[string]string{ "Discovery": "v0.1", "Receiver": "v0.2", - }, &protocolValidatorCurrent{}}, + }, &protocolValidatorDeprecated{getTimefromDate("2022-07-01")}}, Protocol{"v0.1", map[string]string{ "Discovery": "v0.1", diff --git a/discovery/src/asapo_discovery/protocols/protocol_test.go b/discovery/src/asapo_discovery/protocols/protocol_test.go index 4f0e45f869e0003634a5616d0276ff65725eb15e..344699581b2a09b20a7db3f72170dd4ca79d953b 100644 --- a/discovery/src/asapo_discovery/protocols/protocol_test.go +++ b/discovery/src/asapo_discovery/protocols/protocol_test.go @@ -23,7 +23,8 @@ var protocolTests = []protocolTest{ // producer - {"producer", "v0.2", true, "current", "v0.2"}, + {"producer", "v0.3", true, "current", "v0.3"}, + {"producer", "v0.2", true, "deprecates", "v0.2"}, {"producer", "v0.1", true, "deprecates", "v0.1"}, {"producer", "v1000.2", false, "unknown", "unknown protocol"}, } diff --git a/examples/pipeline/in_to_out_python/check_linux.sh b/examples/pipeline/in_to_out_python/check_linux.sh index 4a1a7438585878bf130886ffa9405174ef151ab6..04c4fef60afd82c449055f992e913efaa46f60c5 100644 --- a/examples/pipeline/in_to_out_python/check_linux.sh +++ b/examples/pipeline/in_to_out_python/check_linux.sh @@ -2,15 +2,15 @@ source_path=. beamtime_id=asapo_test -data_source_in=detector -data_source_out=data_source +data_source_in=detector/123 +data_source_out=data_source/12.4 timeout=15 timeout_producer=25 nthreads=4 -indatabase_name=${beamtime_id}_${data_source_in} -outdatabase_name=${beamtime_id}_${data_source_out} +indatabase_name=${beamtime_id}_detector%2F123 +outdatabase_name=${beamtime_id}_data_source%2F12%2E4 #asapo_test read token token=$ASAPO_TEST_RW_TOKEN @@ -33,7 +33,7 @@ Cleanup() { echo "db.dropDatabase()" | mongo ${outdatabase_name} rm -rf processed rm -rf ${receiver_root_folder} - rm -rf out +# rm -rf out } diff --git a/examples/pipeline/in_to_out_python/in_to_out.py b/examples/pipeline/in_to_out_python/in_to_out.py index 0e58c1b0a0daf249ca960344e3088b177de570dd..93e7328e040c8f6eceb1a37dcc555b8c6f6702c7 100644 --- a/examples/pipeline/in_to_out_python/in_to_out.py +++ b/examples/pipeline/in_to_out_python/in_to_out.py @@ -7,6 +7,8 @@ import threading lock = threading.Lock() +print (asapo_consumer.__version__) +print (asapo_producer.__version__) n_send = 0 diff --git a/producer/api/cpp/src/producer_impl.cpp b/producer/api/cpp/src/producer_impl.cpp index e29dfec2e38de5a6a0d6b20db3370b43689f3deb..f13f710200bee38fe466f1452011664b21e8b11c 100644 --- a/producer/api/cpp/src/producer_impl.cpp +++ b/producer/api/cpp/src/producer_impl.cpp @@ -12,16 +12,17 @@ #include "asapo/http_client/http_client.h" #include "asapo/common/internal/version.h" -namespace asapo { +namespace asapo { const size_t ProducerImpl::kDiscoveryServiceUpdateFrequencyMs = 10000; // 10s ProducerImpl::ProducerImpl(std::string endpoint, uint8_t n_processing_threads, uint64_t timeout_ms, - asapo::RequestHandlerType type): + asapo::RequestHandlerType type) : log__{GetDefaultProducerLogger()}, httpclient__{DefaultHttpClient()}, timeout_ms_{timeout_ms}, endpoint_{endpoint} { switch (type) { case RequestHandlerType::kTcp: - discovery_service_.reset(new ReceiverDiscoveryService{endpoint, ProducerImpl::kDiscoveryServiceUpdateFrequencyMs}); + discovery_service_.reset(new ReceiverDiscoveryService{endpoint, + ProducerImpl::kDiscoveryServiceUpdateFrequencyMs}); request_handler_factory_.reset(new ProducerRequestHandlerFactory{discovery_service_.get()}); break; case RequestHandlerType::kFilesystem: @@ -78,7 +79,7 @@ Error CheckProducerRequest(const MessageHeader& message_header, uint64_t ingest_ return ProducerErrorTemplates::kWrongInput.Generate("too long filename"); } - if (message_header.file_name.empty() ) { + if (message_header.file_name.empty()) { return ProducerErrorTemplates::kWrongInput.Generate("empty filename"); } @@ -150,8 +151,11 @@ Error ProducerImpl::Send(const MessageHeader& message_header, auto request_header = GenerateNextSendRequest(message_header, std::move(stream), ingest_mode); - err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> {new ProducerRequest{source_cred_string_, std::move(request_header), - std::move(data), std::move(message_header.user_metadata), std::move(full_path), callback, manage_data_memory, timeout_ms_} + err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> { + new ProducerRequest{ + source_cred_string_, std::move(request_header), + std::move(data), std::move(message_header.user_metadata), std::move(full_path), callback, + manage_data_memory, timeout_ms_} }); return HandleErrorFromPool(std::move(err), manage_data_memory); @@ -194,7 +198,7 @@ Error ProducerImpl::SendStreamFinishedFlag(std::string stream, uint64_t last_id, if (next_stream.empty()) { next_stream = kNoNextStreamKeyword; } - message_header.user_metadata = std::string("{\"next_stream\":") + "\"" + next_stream + "\"}"; + message_header.user_metadata = std::string("{\"next_stream\":") + "\"" + next_stream + "\"}"; return Send(message_header, std::move(stream), nullptr, "", IngestModeFlags::kTransferMetaDataOnly, callback, true); } @@ -237,7 +241,7 @@ Error ProducerImpl::SetCredentials(SourceCredentials source_cred) { } source_cred_string_ = source_cred.GetString(); - if (source_cred_string_.size() + source_cred.user_token.size() > kMaxMessageSize) { + if (source_cred_string_.size() + source_cred.user_token.size() > kMaxMessageSize) { log__->Error("credentials string is too long - " + source_cred_string_); source_cred_string_ = ""; return ProducerErrorTemplates::kWrongInput.Generate("credentials string is too long"); @@ -251,9 +255,11 @@ Error ProducerImpl::SendMetadata(const std::string& metadata, RequestCallback ca request_header.custom_data[kPosIngestMode] = asapo::IngestModeFlags::kTransferData | asapo::IngestModeFlags::kStoreInDatabase; MessageData data{new uint8_t[metadata.size()]}; - strncpy((char*)data.get(), metadata.c_str(), metadata.size()); - auto err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> {new ProducerRequest{source_cred_string_, std::move(request_header), - std::move(data), "", "", callback, true, timeout_ms_} + strncpy((char*) data.get(), metadata.c_str(), metadata.size()); + auto err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> { + new ProducerRequest{ + source_cred_string_, std::move(request_header), + std::move(data), "", "", callback, true, timeout_ms_} }); return HandleErrorFromPool(std::move(err), true); } @@ -263,17 +269,23 @@ Error ProducerImpl::Send__(const MessageHeader& message_header, uint64_t ingest_mode, std::string stream, RequestCallback callback) { - MessageData data_wrapped = MessageData{(uint8_t*)data}; + MessageData data_wrapped = MessageData{(uint8_t*) data}; if (auto err = CheckData(ingest_mode, message_header, &data_wrapped)) { data_wrapped.release(); return err; } - return Send(std::move(message_header), std::move(stream), std::move(data_wrapped), "", ingest_mode, callback, false); + return Send(std::move(message_header), + std::move(stream), + std::move(data_wrapped), + "", + ingest_mode, + callback, + false); } -uint64_t ProducerImpl::GetRequestsQueueSize() { +uint64_t ProducerImpl::GetRequestsQueueSize() { return request_pool__->NRequestsInPool(); }; @@ -301,11 +313,10 @@ Error ProducerImpl::SendFile(const MessageHeader& message_header, } -template<class T > +template<class T> using RequestCallbackWithPromise = void (*)(std::shared_ptr<std::promise<T>>, RequestCallbackPayload header, Error err); - template<class T> RequestCallback unwrap_callback(RequestCallbackWithPromise<T> callback, std::unique_ptr<std::promise<T>> promise) { @@ -329,7 +340,7 @@ void ActivatePromiseForStreamInfo(std::shared_ptr<std::promise<StreamInfoResult> } try { promise->set_value(res); - } catch(...) {} + } catch (...) {} } void ActivatePromiseForErrorInterface(std::shared_ptr<std::promise<ErrorInterface*>> promise, @@ -343,10 +354,9 @@ void ActivatePromiseForErrorInterface(std::shared_ptr<std::promise<ErrorInterfac } try { promise->set_value(res); - } catch(...) {} + } catch (...) {} } - template<class T> T GetResultFromCallback(std::future<T>* promiseResult, uint64_t timeout_ms, Error* err) { try { @@ -354,13 +364,12 @@ T GetResultFromCallback(std::future<T>* promiseResult, uint64_t timeout_ms, Erro if (status == std::future_status::ready) { return promiseResult->get(); } - } catch(...) {} + } catch (...) {} *err = ProducerErrorTemplates::kTimeout.Generate(); return T{}; } - GenericRequestHeader CreateRequestHeaderFromOp(StreamRequestOp op, std::string stream) { switch (op) { case StreamRequestOp::kStreamInfo: @@ -373,15 +382,17 @@ GenericRequestHeader CreateRequestHeaderFromOp(StreamRequestOp op, std::string s StreamInfo ProducerImpl::StreamRequest(StreamRequestOp op, std::string stream, uint64_t timeout_ms, Error* err) const { auto header = CreateRequestHeaderFromOp(op, stream); - std::unique_ptr<std::promise<StreamInfoResult>> promise {new std::promise<StreamInfoResult>}; + std::unique_ptr<std::promise<StreamInfoResult>> promise{new std::promise<StreamInfoResult>}; std::future<StreamInfoResult> promiseResult = promise->get_future(); - *err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> {new ProducerRequest{source_cred_string_, std::move(header), - nullptr, "", "", - unwrap_callback( - ActivatePromiseForStreamInfo, - std::move(promise)), true, - timeout_ms} + *err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> { + new ProducerRequest{ + source_cred_string_, std::move(header), + nullptr, "", "", + unwrap_callback( + ActivatePromiseForStreamInfo, + std::move(promise)), true, + timeout_ms} }, true); if (*err) { return StreamInfo{}; @@ -435,7 +446,7 @@ Error ProducerImpl::GetServerVersionInfo(std::string* server_info, bool* supported) const { auto endpoint = endpoint_ + "/asapo-discovery/" + kProducerProtocol.GetDiscoveryVersion() + "/version?client=producer&protocol=" + kProducerProtocol.GetVersion(); - HttpCode code; + HttpCode code; Error err; auto response = httpclient__->Get(endpoint, &code, &err); if (err) { @@ -448,15 +459,17 @@ Error ProducerImpl::DeleteStream(std::string stream, uint64_t timeout_ms, Delete auto header = GenericRequestHeader{kOpcodeDeleteStream, 0, 0, 0, "", stream}; header.custom_data[0] = options.Encode(); - std::unique_ptr<std::promise<ErrorInterface*>> promise {new std::promise<ErrorInterface*>}; + std::unique_ptr<std::promise<ErrorInterface*>> promise{new std::promise<ErrorInterface*>}; std::future<ErrorInterface*> promiseResult = promise->get_future(); - auto err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> {new ProducerRequest{source_cred_string_, std::move(header), - nullptr, "", "", - unwrap_callback<ErrorInterface*>( - ActivatePromiseForErrorInterface, - std::move(promise)), true, - timeout_ms} + auto err = request_pool__->AddRequest(std::unique_ptr<ProducerRequest> { + new ProducerRequest{ + source_cred_string_, std::move(header), + nullptr, "", "", + unwrap_callback<ErrorInterface*>( + ActivatePromiseForErrorInterface, + std::move(promise)), true, + timeout_ms} }, true); if (err) { return err; diff --git a/producer/api/cpp/unittests/test_producer_impl.cpp b/producer/api/cpp/unittests/test_producer_impl.cpp index caf5a38e6369143ce0150868eadd88883853118b..233511dc7e3f8dbb445c6c49a63625211cb70383 100644 --- a/producer/api/cpp/unittests/test_producer_impl.cpp +++ b/producer/api/cpp/unittests/test_producer_impl.cpp @@ -523,10 +523,10 @@ TEST_F(ProducerImplTests, ReturnDataIfCanotAddToQueue) { TEST_F(ProducerImplTests, GetVersionInfoWithServer) { std::string result = - R"({"softwareVersion":"20.03.1, build 7a9294ad","clientSupported":"no", "clientProtocol":{"versionInfo":"v0.2"}})"; + R"({"softwareVersion":"21.06.0, build 7a9294ad","clientSupported":"no", "clientProtocol":{"versionInfo":"v0.3"}})"; EXPECT_CALL(*mock_http_client, Get_t(HasSubstr(expected_server_uri + - "/asapo-discovery/v0.1/version?client=producer&protocol=v0.2"), _, _)).WillOnce(DoAll( + "/asapo-discovery/v0.1/version?client=producer&protocol=v0.3"), _, _)).WillOnce(DoAll( SetArgPointee<1>(asapo::HttpCode::OK), SetArgPointee<2>(nullptr), Return(result))); @@ -534,8 +534,8 @@ TEST_F(ProducerImplTests, GetVersionInfoWithServer) { std::string client_info, server_info; auto err = producer.GetVersionInfo(&client_info, &server_info, nullptr); ASSERT_THAT(err, Eq(nullptr)); - ASSERT_THAT(server_info, HasSubstr("20.03.1")); - ASSERT_THAT(server_info, HasSubstr("v0.2")); + ASSERT_THAT(server_info, HasSubstr("21.06.0")); + ASSERT_THAT(server_info, HasSubstr("v0.3")); } MATCHER_P4(M_CheckDeleteStreamRequest, op_code, source_credentials, stream, flag, diff --git a/producer/api/cpp/unittests/test_producer_request.cpp b/producer/api/cpp/unittests/test_producer_request.cpp index 2e4509db30c2723ba5a933a1698d0edd53a64269..cf0b8aba9a20becec1928e260146699087d7cff4 100644 --- a/producer/api/cpp/unittests/test_producer_request.cpp +++ b/producer/api/cpp/unittests/test_producer_request.cpp @@ -40,7 +40,7 @@ TEST(ProducerRequest, Constructor) { uint64_t expected_file_size = 1337; uint64_t expected_meta_size = 137; std::string expected_meta = "meta"; - std::string expected_api_version = "v0.2"; + std::string expected_api_version = "v0.3"; asapo::Opcode expected_op_code = asapo::kOpcodeTransferData; asapo::GenericRequestHeader header{expected_op_code, expected_file_id, expected_file_size, diff --git a/producer/api/cpp/unittests/test_receiver_discovery_service.cpp b/producer/api/cpp/unittests/test_receiver_discovery_service.cpp index 956dc9df45e38a23bc0b89b8ca07dfd000eacbc5..d94988a48a1ee545ef9a473ff3684d35ea6d544f 100644 --- a/producer/api/cpp/unittests/test_receiver_discovery_service.cpp +++ b/producer/api/cpp/unittests/test_receiver_discovery_service.cpp @@ -48,7 +48,7 @@ class ReceiversStatusTests : public Test { NiceMock<asapo::MockLogger> mock_logger; NiceMock<MockHttpClient>* mock_http_client; - std::string expected_endpoint{"endpoint/asapo-discovery/v0.1/asapo-receiver?protocol=v0.2"}; + std::string expected_endpoint{"endpoint/asapo-discovery/v0.1/asapo-receiver?protocol=v0.3"}; ReceiverDiscoveryService status{"endpoint", 20}; void SetUp() override { diff --git a/receiver/CMakeLists.txt b/receiver/CMakeLists.txt index 0bc661d5f7eea49a3926a5bc71641d0adb24a4b0..4a18b189606b27a371224d9a12dead194b711ed8 100644 --- a/receiver/CMakeLists.txt +++ b/receiver/CMakeLists.txt @@ -21,7 +21,7 @@ set(RECEIVER_CORE_FILES src/request_handler/request_handler_db_last_stream.cpp src/request_handler/request_handler_receive_metadata.cpp src/request_handler/request_handler_db_check_request.cpp - src/request_handler/request_handler_delete_stream.cpp + src/request_handler/request_handler_db_delete_stream.cpp src/request_handler/request_factory.cpp src/request_handler/request_handler_db.cpp src/file_processors/write_file_processor.cpp diff --git a/receiver/src/request_handler/request_factory.h b/receiver/src/request_handler/request_factory.h index 374c586ab0e7609effa52cfe12017b471561734a..ee371d5aca70e23200079f9a6a3ce5072622cb34 100644 --- a/receiver/src/request_handler/request_factory.h +++ b/receiver/src/request_handler/request_factory.h @@ -6,7 +6,7 @@ #include "../file_processors/receive_file_processor.h" #include "request_handler_db_stream_info.h" #include "request_handler_db_last_stream.h" -#include "request_handler_delete_stream.h" +#include "request_handler_db_delete_stream.h" namespace asapo { @@ -26,7 +26,7 @@ class RequestFactory { RequestHandlerReceiveMetaData request_handler_receive_metadata_; RequestHandlerDbWrite request_handler_dbwrite_{kDBDataCollectionNamePrefix}; RequestHandlerDbStreamInfo request_handler_db_stream_info_{kDBDataCollectionNamePrefix}; - RequestHandlerDeleteStream request_handler_delete_stream_{kDBDataCollectionNamePrefix}; + RequestHandlerDbDeleteStream request_handler_delete_stream_{kDBDataCollectionNamePrefix}; RequestHandlerDbLastStream request_handler_db_last_stream_{kDBDataCollectionNamePrefix}; RequestHandlerDbMetaWrite request_handler_db_meta_write_{kDBMetaCollectionName}; RequestHandlerAuthorize request_handler_authorize_; diff --git a/receiver/src/request_handler/request_handler_db.cpp b/receiver/src/request_handler/request_handler_db.cpp index 821f0d770a551e27e2a1c0737e4902ce46e5519e..60ad1eb9623789cb486915db718590c613c11c9d 100644 --- a/receiver/src/request_handler/request_handler_db.cpp +++ b/receiver/src/request_handler/request_handler_db.cpp @@ -2,6 +2,7 @@ #include "../receiver_config.h" #include "../receiver_logger.h" #include "../request.h" +#include "asapo/database/db_error.h" namespace asapo { @@ -12,13 +13,13 @@ Error RequestHandlerDb::ProcessRequest(Request* request) const { db_name_ += "_" + data_source; } - return ConnectToDbIfNeeded(); } -RequestHandlerDb::RequestHandlerDb(std::string collection_name_prefix): log__{GetDefaultReceiverLogger()}, +RequestHandlerDb::RequestHandlerDb(std::string collection_name_prefix) : log__{GetDefaultReceiverLogger()}, http_client__{DefaultHttpClient()}, - collection_name_prefix_{std::move(collection_name_prefix)} { + collection_name_prefix_{ + std::move(collection_name_prefix)} { DatabaseFactory factory; Error err; db_client__ = factory.Create(&err); @@ -28,7 +29,6 @@ StatisticEntity RequestHandlerDb::GetStatisticEntity() const { return StatisticEntity::kDatabase; } - Error RequestHandlerDb::GetDatabaseServerUri(std::string* uri) const { if (GetReceiverConfig()->database_uri != "auto") { *uri = GetReceiverConfig()->database_uri; @@ -39,15 +39,17 @@ Error RequestHandlerDb::GetDatabaseServerUri(std::string* uri) const { Error http_err; *uri = http_client__->Get(GetReceiverConfig()->discovery_server + "/asapo-mongodb", &code, &http_err); if (http_err) { - log__->Error(std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server - + " : " + http_err->Explain()); + log__->Error( + std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server + + " : " + http_err->Explain()); return ReceiverErrorTemplates::kInternalServerError.Generate("http error when discover database server" + http_err->Explain()); } if (code != HttpCode::OK) { - log__->Error(std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server - + " : http code" + std::to_string((int)code)); + log__->Error( + std::string{"http error when discover database server "} + " from " + GetReceiverConfig()->discovery_server + + " : http code" + std::to_string((int) code)); return ReceiverErrorTemplates::kInternalServerError.Generate("error when discover database server"); } @@ -56,7 +58,6 @@ Error RequestHandlerDb::GetDatabaseServerUri(std::string* uri) const { return nullptr; } - Error RequestHandlerDb::ConnectToDbIfNeeded() const { if (!connected_to_db) { std::string uri; @@ -66,12 +67,24 @@ Error RequestHandlerDb::ConnectToDbIfNeeded() const { } err = db_client__->Connect(uri, db_name_); if (err) { - return ReceiverErrorTemplates::kInternalServerError.Generate("error connecting to database " + err->Explain()); + return DBErrorToReceiverError(err); } connected_to_db = true; } return nullptr; } +Error RequestHandlerDb::DBErrorToReceiverError(const Error& err) const { + if (err == nullptr) { + return nullptr; + } + std::string msg = "database error: " + err->Explain(); + if (err == DBErrorTemplates::kWrongInput || err == DBErrorTemplates::kNoRecord + || err == DBErrorTemplates::kJsonParseError) { + return ReceiverErrorTemplates::kBadRequest.Generate(msg); + } + + return ReceiverErrorTemplates::kInternalServerError.Generate(msg); +} } diff --git a/receiver/src/request_handler/request_handler_db.h b/receiver/src/request_handler/request_handler_db.h index 0ea67d20538c2a42eb7ca61872a0840e9faca457..0c006e6e3eda365143ccae52baaa71ecca77a783 100644 --- a/receiver/src/request_handler/request_handler_db.h +++ b/receiver/src/request_handler/request_handler_db.h @@ -21,6 +21,7 @@ class RequestHandlerDb : public ReceiverRequestHandler { std::unique_ptr<HttpClient> http_client__; protected: Error ConnectToDbIfNeeded() const; + Error DBErrorToReceiverError(const Error& err) const; mutable bool connected_to_db = false; mutable std::string db_name_; std::string collection_name_prefix_; diff --git a/receiver/src/request_handler/request_handler_db_check_request.cpp b/receiver/src/request_handler/request_handler_db_check_request.cpp index b43347fee946e7bc5a48cbb31296afcb65487271..3d33e1441c1eb2394ab4bcb75c1173cef5f11f0d 100644 --- a/receiver/src/request_handler/request_handler_db_check_request.cpp +++ b/receiver/src/request_handler/request_handler_db_check_request.cpp @@ -58,7 +58,7 @@ Error RequestHandlerDbCheckRequest::ProcessRequest(Request* request) const { MessageMeta record; auto err = GetRecordFromDb(request, &record); if (err) { - return err == DBErrorTemplates::kNoRecord ? nullptr : std::move(err); + return DBErrorToReceiverError(err == DBErrorTemplates::kNoRecord ? nullptr : std::move(err)); } if (SameRequestInRecord(request, record)) { diff --git a/receiver/src/request_handler/request_handler_delete_stream.cpp b/receiver/src/request_handler/request_handler_db_delete_stream.cpp similarity index 78% rename from receiver/src/request_handler/request_handler_delete_stream.cpp rename to receiver/src/request_handler/request_handler_db_delete_stream.cpp index 3f58f002f69724cd38908d6da2736ed55c6a4281..8ff55033d85e767d33626914f7c667a75ca5a94a 100644 --- a/receiver/src/request_handler/request_handler_delete_stream.cpp +++ b/receiver/src/request_handler/request_handler_db_delete_stream.cpp @@ -1,14 +1,14 @@ -#include "request_handler_delete_stream.h" +#include "request_handler_db_delete_stream.h" #include "../receiver_config.h" #include <asapo/database/db_error.h> namespace asapo { -RequestHandlerDeleteStream::RequestHandlerDeleteStream(std::string collection_name_prefix) : RequestHandlerDb( +RequestHandlerDbDeleteStream::RequestHandlerDbDeleteStream(std::string collection_name_prefix) : RequestHandlerDb( std::move(collection_name_prefix)) { } -Error RequestHandlerDeleteStream::ProcessRequest(Request* request) const { +Error RequestHandlerDbDeleteStream::ProcessRequest(Request* request) const { if (auto err = RequestHandlerDb::ProcessRequest(request) ) { return err; } @@ -36,7 +36,7 @@ Error RequestHandlerDeleteStream::ProcessRequest(Request* request) const { return nullptr; } - return err; + return DBErrorToReceiverError(err); } diff --git a/receiver/src/request_handler/request_handler_db_delete_stream.h b/receiver/src/request_handler/request_handler_db_delete_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..b710fb04bdb60a099f3d782889952dd5a9c4a822 --- /dev/null +++ b/receiver/src/request_handler/request_handler_db_delete_stream.h @@ -0,0 +1,18 @@ +#ifndef ASAPO_REQUEST_HANDLER_DB_DELETE_STREAM_H +#define ASAPO_REQUEST_HANDLER_DB_DELETE_STREAM_H + +#include "request_handler_db.h" +#include "../request.h" + +namespace asapo { + +class RequestHandlerDbDeleteStream final: public RequestHandlerDb { + public: + RequestHandlerDbDeleteStream(std::string collection_name_prefix); + Error ProcessRequest(Request* request) const override; +}; + +} + + +#endif //ASAPO_REQUEST_HANDLER_DB_DELETE_STREAM_H diff --git a/receiver/src/request_handler/request_handler_db_last_stream.cpp b/receiver/src/request_handler/request_handler_db_last_stream.cpp index fbe5b8035b64d203bcef4f70c58f3311a5cd9916..f1c5b959494e26d3c8e1c139be7fc59d4b0d1f15 100644 --- a/receiver/src/request_handler/request_handler_db_last_stream.cpp +++ b/receiver/src/request_handler/request_handler_db_last_stream.cpp @@ -22,7 +22,7 @@ Error RequestHandlerDbLastStream::ProcessRequest(Request* request) const { db_name_ + " at " + GetReceiverConfig()->database_uri); request->SetResponseMessage(info.Json(), ResponseMessageType::kInfo); } - return err; + return DBErrorToReceiverError(err); } } \ No newline at end of file diff --git a/receiver/src/request_handler/request_handler_db_meta_write.cpp b/receiver/src/request_handler/request_handler_db_meta_write.cpp index 57fb21051f846e708c858bb8396e5e968b44e867..3f1c89f0ff175626a3ad1a3cda47ece8ee32daf2 100644 --- a/receiver/src/request_handler/request_handler_db_meta_write.cpp +++ b/receiver/src/request_handler/request_handler_db_meta_write.cpp @@ -22,7 +22,7 @@ Error RequestHandlerDbMetaWrite::ProcessRequest(Request* request) const { db_name_ + " at " + GetReceiverConfig()->database_uri); } - return err; + return DBErrorToReceiverError(err); } RequestHandlerDbMetaWrite::RequestHandlerDbMetaWrite(std::string collection_name) : RequestHandlerDb(std::move( collection_name)) { diff --git a/receiver/src/request_handler/request_handler_db_stream_info.cpp b/receiver/src/request_handler/request_handler_db_stream_info.cpp index 65d194ccfa1f570fa51341d58e6e3b799a50528c..ff9a8ab935c208a986b41038a6be1ccd03e65c32 100644 --- a/receiver/src/request_handler/request_handler_db_stream_info.cpp +++ b/receiver/src/request_handler/request_handler_db_stream_info.cpp @@ -23,7 +23,7 @@ Error RequestHandlerDbStreamInfo::ProcessRequest(Request* request) const { info.name = request->GetStream(); request->SetResponseMessage(info.Json(), ResponseMessageType::kInfo); } - return err; + return DBErrorToReceiverError(err); } } \ No newline at end of file diff --git a/receiver/src/request_handler/request_handler_db_write.cpp b/receiver/src/request_handler/request_handler_db_write.cpp index d0113286aab69657d27e4ed7c69fdd272b411855..db28504e4f66991a18f257362d8e6e76ade59795 100644 --- a/receiver/src/request_handler/request_handler_db_write.cpp +++ b/receiver/src/request_handler/request_handler_db_write.cpp @@ -34,7 +34,7 @@ Error RequestHandlerDbWrite::ProcessRequest(Request* request) const { if (err == DBErrorTemplates::kDuplicateID) { return ProcessDuplicateRecordSituation(request); } else { - return err; + return DBErrorToReceiverError(err); } } diff --git a/receiver/src/request_handler/request_handler_delete_stream.h b/receiver/src/request_handler/request_handler_delete_stream.h deleted file mode 100644 index a9550da652a7969fe59c0487d01c2ae98ea73c0d..0000000000000000000000000000000000000000 --- a/receiver/src/request_handler/request_handler_delete_stream.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef ASAPO_REQUEST_HANDLER_DELETE_STREAM_H -#define ASAPO_REQUEST_HANDLER_DELETE_STREAM_H - -#include "request_handler_db.h" -#include "../request.h" - -namespace asapo { - -class RequestHandlerDeleteStream final: public RequestHandlerDb { - public: - RequestHandlerDeleteStream(std::string collection_name_prefix); - Error ProcessRequest(Request* request) const override; -}; - -} - - -#endif //ASAPO_REQUEST_HANDLER_DELETE_STREAM_H diff --git a/receiver/src/request_handler/requests_dispatcher.cpp b/receiver/src/request_handler/requests_dispatcher.cpp index 8d455393a5685e4352152b4f296dc05ab84e303c..ad009aba5addd6dbc66bec875df4900554c3905f 100644 --- a/receiver/src/request_handler/requests_dispatcher.cpp +++ b/receiver/src/request_handler/requests_dispatcher.cpp @@ -22,8 +22,7 @@ NetworkErrorCode GetNetworkCodeFromError(const Error& err) { return NetworkErrorCode::kNetErrorNotSupported; } else if (err == ReceiverErrorTemplates::kReAuthorizationFailure) { return NetworkErrorCode::kNetErrorReauthorize; - } else if (err == DBErrorTemplates::kJsonParseError || err == ReceiverErrorTemplates::kBadRequest - || err == DBErrorTemplates::kNoRecord) { + } else if (err == ReceiverErrorTemplates::kBadRequest) { return NetworkErrorCode::kNetErrorWrongRequest; } else { return NetworkErrorCode::kNetErrorInternalServerError; diff --git a/receiver/unittests/request_handler/test_request_factory.cpp b/receiver/unittests/request_handler/test_request_factory.cpp index 718ddd901b5e6da20f8591d9dc6b15e444e96b57..03d8bf4495f279493a1c9b90874c4a79b0e73ae8 100644 --- a/receiver/unittests/request_handler/test_request_factory.cpp +++ b/receiver/unittests/request_handler/test_request_factory.cpp @@ -15,7 +15,7 @@ #include "../../src/request_handler/request_handler_authorize.h" #include "../../src/request_handler/request_handler_db_stream_info.h" #include "../../src/request_handler/request_handler_db_last_stream.h" -#include "../../src/request_handler/request_handler_delete_stream.h" +#include "../../src/request_handler/request_handler_db_delete_stream.h" #include "../../src/request_handler/request_handler_receive_data.h" #include "../../src/request_handler/request_handler_receive_metadata.h" @@ -222,7 +222,7 @@ TEST_F(FactoryTests, DeleteStreamRequest) { ASSERT_THAT(err, Eq(nullptr)); ASSERT_THAT(request->GetListHandlers().size(), Eq(2)); ASSERT_THAT(dynamic_cast<const asapo::RequestHandlerAuthorize*>(request->GetListHandlers()[0]), Ne(nullptr)); - ASSERT_THAT(dynamic_cast<const asapo::RequestHandlerDeleteStream*>(request->GetListHandlers()[1]), Ne(nullptr)); + ASSERT_THAT(dynamic_cast<const asapo::RequestHandlerDbDeleteStream*>(request->GetListHandlers()[1]), Ne(nullptr)); } diff --git a/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp b/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp index 85c2b5da91967daf92d26146e08bd6eef6030e4f..05d43fca83158008680a543895734eb8dcf7e9f9 100644 --- a/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp +++ b/receiver/unittests/request_handler/test_request_handler_db_check_request.cpp @@ -261,7 +261,7 @@ TEST_F(DbCheckRequestHandlerTests, ErrorIfDbError) { for (auto mock : mock_functions) { mock(asapo::DBErrorTemplates::kConnectionError.Generate().release(), false); auto err = handler.ProcessRequest(mock_request.get()); - ASSERT_THAT(err, Eq(asapo::DBErrorTemplates::kConnectionError)); + ASSERT_THAT(err, Eq(asapo::ReceiverErrorTemplates::kInternalServerError)); Mock::VerifyAndClearExpectations(mock_request.get()); } } diff --git a/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp b/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp index 94eaf86196578d3b54237ce68e3ed576faaba725..bb169b257b73b39600a8c553f4b3ba0e712bbf61 100644 --- a/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp +++ b/receiver/unittests/request_handler/test_request_handler_delete_stream.cpp @@ -10,7 +10,7 @@ #include "../../src/request.h" #include "../../src/request_handler/request_factory.h" #include "../../src/request_handler/request_handler.h" -#include "../../src/request_handler/request_handler_delete_stream.h" +#include "../../src/request_handler/request_handler_db_delete_stream.h" #include "../../../common/cpp/src/database/mongodb_client.h" #include "../mock_receiver_config.h" @@ -43,7 +43,7 @@ using ::asapo::FileDescriptor; using ::asapo::SocketDescriptor; using ::asapo::MockIO; using asapo::Request; -using asapo::RequestHandlerDeleteStream; +using asapo::RequestHandlerDbDeleteStream; using ::asapo::GenericRequestHeader; using asapo::MockDatabase; @@ -56,7 +56,7 @@ namespace { class DbMetaDeleteStreamTests : public Test { public: - RequestHandlerDeleteStream handler{asapo::kDBDataCollectionNamePrefix}; + RequestHandlerDbDeleteStream handler{asapo::kDBDataCollectionNamePrefix}; std::unique_ptr<NiceMock<MockRequest>> mock_request; NiceMock<MockDatabase> mock_db; NiceMock<asapo::MockLogger> mock_logger; @@ -131,7 +131,7 @@ TEST_F(DbMetaDeleteStreamTests, CallsDeleteErrorAlreadyExist) { ExpectDelete(3, &asapo::DBErrorTemplates::kNoRecord); auto err = handler.ProcessRequest(mock_request.get()); - ASSERT_THAT(err, Eq(asapo::DBErrorTemplates::kNoRecord)); + ASSERT_THAT(err, Eq(asapo::ReceiverErrorTemplates::kBadRequest)); } TEST_F(DbMetaDeleteStreamTests, CallsDeleteNoErrorAlreadyExist) { diff --git a/receiver/unittests/request_handler/test_requests_dispatcher.cpp b/receiver/unittests/request_handler/test_requests_dispatcher.cpp index 73ade897611b663cd98209d65c51898764d296f6..9454aba1564b61886b4b37d78a1e2d8cf80c2a93 100644 --- a/receiver/unittests/request_handler/test_requests_dispatcher.cpp +++ b/receiver/unittests/request_handler/test_requests_dispatcher.cpp @@ -317,17 +317,6 @@ TEST_F(RequestsDispatcherTests, ProcessRequestReturnsReAuthorizationFailure) { } -TEST_F(RequestsDispatcherTests, ProcessRequestReturnsMetaDataFailure) { - MockHandleRequest(1, asapo::DBErrorTemplates::kJsonParseError.Generate()); - MockSendResponse(&response, false); - - auto err = dispatcher->ProcessRequest(request); - - ASSERT_THAT(err, Eq(asapo::DBErrorTemplates::kJsonParseError)); - ASSERT_THAT(response.error_code, Eq(asapo::kNetErrorWrongRequest)); - ASSERT_THAT(std::string(response.message), HasSubstr("parse")); -} - TEST_F(RequestsDispatcherTests, ProcessRequestReturnsBadRequest) { MockHandleRequest(1, asapo::ReceiverErrorTemplates::kBadRequest.Generate()); MockSendResponse(&response, false); @@ -338,14 +327,5 @@ TEST_F(RequestsDispatcherTests, ProcessRequestReturnsBadRequest) { } -TEST_F(RequestsDispatcherTests, ProcessRequestReturnsBNoRecord) { - MockHandleRequest(1, asapo::DBErrorTemplates::kNoRecord.Generate()); - MockSendResponse(&response, false); - - auto err = dispatcher->ProcessRequest(request); - - ASSERT_THAT(response.error_code, Eq(asapo::kNetErrorWrongRequest)); -} - } diff --git a/tests/automatic/consumer/consumer_api_python/consumer_api.py b/tests/automatic/consumer/consumer_api_python/consumer_api.py index e15030262a6599d559e00e276d660a85f2cf8449..54baf5d799b45790d3ca803aeaa3f78927c29ed6 100644 --- a/tests/automatic/consumer/consumer_api_python/consumer_api.py +++ b/tests/automatic/consumer/consumer_api_python/consumer_api.py @@ -118,14 +118,6 @@ def check_single(consumer, group_id): assert_metaname(meta, "5", "get next6") assert_usermetadata(meta, "get next6") - try: - consumer.get_next("_wrong_group_name", meta_only=True) - except asapo_consumer.AsapoWrongInputError as err: - print(err) - pass - else: - exit_on_noerr("should give wrong input error") - try: consumer.get_last(meta_only=False) except asapo_consumer.AsapoLocalIOError as err: diff --git a/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh b/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh index adf03ce9569285fb5f692cc9f9c96685cc1268d0..71093a2577be0c3a864a60150290b04338c6f057 100644 --- a/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh +++ b/tests/automatic/mongo_db/insert_retrieve/cleanup_linux.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -database_name=data +database_name=data_%2F%20%5C%2E%22%24 echo "db.dropDatabase()" | mongo ${database_name} diff --git a/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat b/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat index 53b3cbee6ad380a90cb12999ff08e6724fe90d7b..8a64516d46643ebc6910c3c6b8e2248f7e779ac1 100644 --- a/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat +++ b/tests/automatic/mongo_db/insert_retrieve/cleanup_windows.bat @@ -1,4 +1,4 @@ -SET database_name=data +SET database_name=data_%%2F%%20%%5C%%2E%%22%%24 SET mongo_exe="c:\Program Files\MongoDB\Server\4.2\bin\mongo.exe" echo db.dropDatabase() | %mongo_exe% %database_name% diff --git a/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp b/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp index 4da3b42cc829861f6e85d76ceb4aa9a9a3344918..c6596be84bff2e03b63d2a50c4e8b271c5cdd622 100644 --- a/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp +++ b/tests/automatic/mongo_db/insert_retrieve/insert_retrieve_mongodb.cpp @@ -3,6 +3,8 @@ #include <thread> #include "../../../common/cpp/src/database/mongodb_client.h" +#include "asapo/database/db_error.h" + #include "testing.h" #include "asapo/common/data_structs.h" @@ -32,6 +34,19 @@ Args GetArgs(int argc, char* argv[]) { return Args{argv[1], atoi(argv[2])}; } +std::string GenRandomString(int len) { + std::string s; + static const char alphanum[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + + for (int i = 0; i < len; ++i) { + s += alphanum[rand() % (sizeof(alphanum) - 1)]; + } + + return s; +} int main(int argc, char* argv[]) { auto args = GetArgs(argc, argv); @@ -46,15 +61,18 @@ int main(int argc, char* argv[]) { fi.source = "host:1234"; + auto db_name = R"(data_/ \."$)"; + auto stream_name = R"(bla/test_/\ ."$)"; + if (args.keyword != "Notconnected") { - db.Connect("127.0.0.1", "data"); + db.Connect("127.0.0.1", db_name); } - auto err = db.Insert("data_test", fi, false); + auto err = db.Insert(std::string("data_") + stream_name, fi, false); if (args.keyword == "DuplicateID") { Assert(err, "OK"); - err = db.Insert("data_test", fi, false); + err = db.Insert(std::string("data_") + stream_name, fi, false); } std::this_thread::sleep_for(std::chrono::milliseconds(10)); @@ -73,17 +91,17 @@ int main(int argc, char* argv[]) { if (args.keyword == "OK") { // check retrieve and stream delete asapo::MessageMeta fi_db; asapo::MongoDBClient db_new; - db_new.Connect("127.0.0.1", "data"); - err = db_new.GetById("data_test", fi.id, &fi_db); + db_new.Connect("127.0.0.1", db_name); + err = db_new.GetById(std::string("data_") + stream_name, fi.id, &fi_db); M_AssertTrue(fi_db == fi, "get record from db"); M_AssertEq(nullptr, err); - err = db_new.GetById("data_test", 0, &fi_db); + err = db_new.GetById(std::string("data_") + stream_name, 0, &fi_db); Assert(err, "No record"); asapo::StreamInfo info; - err = db.GetStreamInfo("data_test", &info); + err = db.GetStreamInfo(std::string("data_") + stream_name, &info); M_AssertEq(nullptr, err); M_AssertEq(fi.id, info.last_id); @@ -95,24 +113,37 @@ int main(int argc, char* argv[]) { M_AssertEq("ns", info.next_stream); // delete stream - db.Insert("inprocess_test_blabla", fi, false); - db.Insert("inprocess_test_blabla1", fi, false); - db.Insert("acks_test_blabla", fi, false); - db.Insert("acks_test_blabla1", fi, false); - db.DeleteStream("test"); - err = db.GetStreamInfo("data_test", &info); + db.Insert(std::string("inprocess_") + stream_name + "_blabla", fi, false); + db.Insert(std::string("inprocess_") + stream_name + "_blabla1", fi, false); + db.Insert(std::string("acks_") + stream_name + "_blabla", fi, false); + db.Insert(std::string("acks_") + stream_name + "_blabla1", fi, false); + db.DeleteStream(stream_name); + err = db.GetStreamInfo(std::string("data_") + stream_name, &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("inprocess_test_blabla", &info); + err = db.GetStreamInfo(std::string("inprocess_") + stream_name + "_blabla", &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("inprocess_test_blabla1", &info); + err = db.GetStreamInfo(std::string("inprocess_") + stream_name + "_blabla1", &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("acks_test_blabla", &info); + err = db.GetStreamInfo(std::string("acks_") + stream_name + "_blabla", &info); M_AssertTrue(info.last_id == 0); - err = db.GetStreamInfo("acks_test_blabla1", &info); + err = db.GetStreamInfo(std::string("acks_") + stream_name + "_blabla1", &info); M_AssertTrue(info.last_id == 0); err = db.DeleteStream("test1"); M_AssertTrue(err == nullptr); } + // long names + + asapo::MongoDBClient db1; + auto long_db_name = GenRandomString(64); + err = db1.Connect("127.0.0.1", long_db_name); + M_AssertTrue(err == asapo::DBErrorTemplates::kWrongInput); + + db1.Connect("127.0.0.1", db_name); + auto long_stream_name = GenRandomString(120); + err = db1.Insert(long_stream_name, fi, true); + M_AssertTrue(err == asapo::DBErrorTemplates::kWrongInput); + + return 0; } diff --git a/tests/automatic/producer/python_api/producer_api.py b/tests/automatic/producer/python_api/producer_api.py index b4061c3cf08e8127721437b49a2772c886f5c8a5..64e30bc439739f6e32d9481271c002739d9f717a 100644 --- a/tests/automatic/producer/python_api/producer_api.py +++ b/tests/automatic/producer/python_api/producer_api.py @@ -119,7 +119,7 @@ producer.wait_requests_finished(50000) # send to another stream producer.send(1, "processed/" + data_source + "/" + "file9", None, - ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream", callback=callback) + ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream/test $", callback=callback) # wait normal requests finished before sending duplicates @@ -149,7 +149,7 @@ assert_eq(n, 0, "requests in queue") # send another data to stream stream producer.send(2, "processed/" + data_source + "/" + "file10", None, - ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream", callback=callback) + ingest_mode=asapo_producer.INGEST_MODE_TRANSFER_METADATA_ONLY, stream="stream/test $", callback=callback) producer.wait_requests_finished(50000) n = producer.get_requests_queue_size() @@ -168,9 +168,9 @@ else: #stream_finished producer.wait_requests_finished(10000) -producer.send_stream_finished_flag("stream", 2, next_stream = "next_stream", callback = callback) +producer.send_stream_finished_flag("stream/test $", 2, next_stream = "next_stream", callback = callback) # check callback_object.callback works, will be duplicated request -producer.send_stream_finished_flag("stream", 2, next_stream = "next_stream", callback = callback_object.callback) +producer.send_stream_finished_flag("stream/test $", 2, next_stream = "next_stream", callback = callback_object.callback) producer.wait_requests_finished(10000) @@ -185,7 +185,7 @@ assert_eq(info['timestampLast']/1000000000>time.time()-10,True , "stream_info ti print("created: ",datetime.utcfromtimestamp(info['timestampCreated']/1000000000).strftime('%Y-%m-%d %H:%M:%S.%f')) print("last record: ",datetime.utcfromtimestamp(info['timestampLast']/1000000000).strftime('%Y-%m-%d %H:%M:%S.%f')) -info = producer.stream_info('stream') +info = producer.stream_info('stream/test $') assert_eq(info['lastId'], 3, "last id from different stream") assert_eq(info['finished'], True, "stream finished") @@ -199,12 +199,12 @@ assert_eq(info['lastId'], 0, "last id from non existing stream") info_last = producer.last_stream() print(info_last) -assert_eq(info_last['name'], "stream", "last stream") +assert_eq(info_last['name'], "stream/test $", "last stream") assert_eq(info_last['timestampCreated'] <= info_last['timestampLast'], True, "last is later than first") #delete_streams -producer.delete_stream('stream') -producer.stream_info('stream') +producer.delete_stream('stream/test $') +producer.stream_info('stream/test $') assert_eq(info['lastId'], 0, "last id from non deleted stream")