Skip to content

Commit

Permalink
add comments + fix linter
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgigi96 committed Feb 14, 2023
1 parent 1983f5e commit 9c0d95d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
18 changes: 18 additions & 0 deletions pkg/ocm/share/repository/sql/conversions.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,52 @@ import (
"github.com/cs3org/reva/pkg/ocm/share"
)

// ShareType is the type of the share.
type ShareType int

// AccessMethod is method granted by the sharer to access
// the shared resource.
type AccessMethod int

// Protocol is the protocol the recipient of the share
// uses to access the shared resource.
type Protocol int

// ShareState is the state of the share.
type ShareState int

const (
// ShareTypeUser is used for a share to an user.
ShareTypeUser ShareType = iota
// ShareTypeGroup is used for a share to a group.
ShareTypeGroup
)

const (
// ShareStatePending is the state for a pending share.
ShareStatePending ShareState = iota
// ShareStateAccepted is the share for an accepted share.
ShareStateAccepted
// ShareStateRejected is the share for a rejected share.
ShareStateRejected
)

const (
// WebDAVAccessMethod indicates an access using WebDAV to the share.
WebDAVAccessMethod AccessMethod = iota
// WebappAccessMethod indicates an access using a collaborative
// application to the share.
WebappAccessMethod
// TransferAccessMethod indicates a share for a transfer.
TransferAccessMethod
)

const (
// WebDAVProtocol is the WebDav protocol.
WebDAVProtocol Protocol = iota
// WebappProtcol is the Webapp protocol.
WebappProtcol
// TransferProtocol is the Transfer protocol.
TransferProtocol
)

Expand Down
26 changes: 14 additions & 12 deletions pkg/ocm/share/repository/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func init() {
registry.Register("sql", New)
}

// New creates a Repository with a SQL driver.
func New(c map[string]interface{}) (share.Repository, error) {
conf, err := parseConfig(c)
if err != nil {
Expand Down Expand Up @@ -80,7 +81,7 @@ func parseConfig(conf map[string]interface{}) (*config, error) {
return &c, nil
}

func formatUserId(u *userpb.UserId) string {
func formatUserID(u *userpb.UserId) string {
return fmt.Sprintf("%s@%s", u.OpaqueId, u.Idp)
}

Expand Down Expand Up @@ -128,10 +129,10 @@ func storeAccessMethod(tx *sql.Tx, shareID int64, t AccessMethod) (int64, error)

// StoreShare stores a share.
func (m *mgr) StoreShare(ctx context.Context, s *ocm.Share) (*ocm.Share, error) {
if err := Transaction(ctx, m.db, func(tx *sql.Tx) error {
if err := transaction(ctx, m.db, func(tx *sql.Tx) error {
// store the share
query := "INSERT INTO ocm_shares SET token=?,fileid_prefix=?,item_source=?,name=?,share_with=?,owner=?,initiator=?,ctime=?,mtime=?,type=?"
params := []any{s.Token, s.ResourceId.StorageId, s.ResourceId.OpaqueId, s.Name, formatUserId(s.Grantee.GetUserId()), s.Owner.OpaqueId, s.Creator.OpaqueId, s.Ctime.Seconds, s.Mtime.Seconds, convertFromCS3OCMShareType(s.ShareType)}
params := []any{s.Token, s.ResourceId.StorageId, s.ResourceId.OpaqueId, s.Name, formatUserID(s.Grantee.GetUserId()), s.Owner.OpaqueId, s.Creator.OpaqueId, s.Ctime.Seconds, s.Mtime.Seconds, convertFromCS3OCMShareType(s.ShareType)}

if s.Expiration != nil {
query += ",expiration=?"
Expand Down Expand Up @@ -168,7 +169,6 @@ func (m *mgr) StoreShare(ctx context.Context, s *ocm.Share) (*ocm.Share, error)

s.Id = &ocm.ShareId{OpaqueId: strconv.FormatInt(id, 10)}
return nil

}); err != nil {
// check if the share already exists in the db
// https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html#error_er_dup_unique
Expand All @@ -182,7 +182,9 @@ func (m *mgr) StoreShare(ctx context.Context, s *ocm.Share) (*ocm.Share, error)
return s, nil
}

func Transaction(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
// this func will run f in a transaction, committing if no errors
// rolling back if there were error running f.
func transaction(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return err
Expand All @@ -191,9 +193,9 @@ func Transaction(ctx context.Context, db *sql.DB, f func(*sql.Tx) error) error {
var txErr error
defer func() {
if txErr == nil {
tx.Commit()
_ = tx.Commit()
} else {
tx.Rollback()
_ = tx.Rollback()
}
}()

Expand Down Expand Up @@ -242,7 +244,7 @@ func (m *mgr) getByKey(ctx context.Context, user *userpb.User, key *ocm.ShareKey
query := "SELECT id, token, fileid_prefix, item_source, name, share_with, owner, initiator, ctime, mtime, expiration, type FROM ocm_shares WHERE owner=? AND fileid_prefix=? AND item_source=? AND share_with=? AND (initiator=? OR owner=?)"

var s dbShare
if err := m.db.QueryRowContext(ctx, query, key.Owner.OpaqueId, key.ResourceId.StorageId, key.ResourceId.OpaqueId, formatUserId(key.Grantee.GetUserId()), user.Id.OpaqueId, user.Id.OpaqueId).Scan(&s.ID, &s.Token, &s.Prefix, &s.ItemSource, &s.Name, &s.ShareWith, &s.Owner, &s.Initiator, &s.Ctime, &s.Mtime, &s.Expiration, &s.ShareType); err != nil {
if err := m.db.QueryRowContext(ctx, query, key.Owner.OpaqueId, key.ResourceId.StorageId, key.ResourceId.OpaqueId, formatUserID(key.Grantee.GetUserId()), user.Id.OpaqueId, user.Id.OpaqueId).Scan(&s.ID, &s.Token, &s.Prefix, &s.ItemSource, &s.Name, &s.ShareWith, &s.Owner, &s.Initiator, &s.Ctime, &s.Mtime, &s.Expiration, &s.ShareType); err != nil {
if err == sql.ErrNoRows {
return nil, share.ErrShareNotFound
}
Expand Down Expand Up @@ -300,7 +302,7 @@ func (m *mgr) deleteByID(ctx context.Context, user *userpb.User, id *ocm.ShareId

func (m *mgr) deleteByKey(ctx context.Context, user *userpb.User, key *ocm.ShareKey) error {
query := "DELETE FROM ocm_shares WHERE owner=? AND fileid_prefix=? AND item_source=? AND share_with=? AND (initiator=? OR owner=?)"
_, err := m.db.ExecContext(ctx, query, key.Owner.OpaqueId, key.ResourceId.StorageId, key.ResourceId.OpaqueId, formatUserId(key.Grantee.GetUserId()), user.Id.OpaqueId, user.Id.OpaqueId)
_, err := m.db.ExecContext(ctx, query, key.Owner.OpaqueId, key.ResourceId.StorageId, key.ResourceId.OpaqueId, formatUserID(key.Grantee.GetUserId()), user.Id.OpaqueId, user.Id.OpaqueId)
return err
}

Expand Down Expand Up @@ -490,9 +492,9 @@ func storeProtocol(tx *sql.Tx, shareID int64, p Protocol) (int64, error) {

// StoreReceivedShare stores a received share.
func (m *mgr) StoreReceivedShare(ctx context.Context, s *ocm.ReceivedShare) (*ocm.ReceivedShare, error) {
if err := Transaction(ctx, m.db, func(tx *sql.Tx) error {
if err := transaction(ctx, m.db, func(tx *sql.Tx) error {
query := "INSERT INTO ocm_received_shares SET name=?,fileid_prefix=?,item_source=?,share_with=?,owner=?,initiator=?,ctime=?,mtime=?,type=?,state=?"
params := []any{s.Name, s.ResourceId.StorageId, s.ResourceId.OpaqueId, s.Grantee.GetUserId().OpaqueId, formatUserId(s.Owner), formatUserId(s.Creator), s.Ctime.Seconds, s.Mtime.Seconds, convertFromCS3OCMShareType(s.ShareType), convertFromCS3OCMShareState(s.State)}
params := []any{s.Name, s.ResourceId.StorageId, s.ResourceId.OpaqueId, s.Grantee.GetUserId().OpaqueId, formatUserID(s.Owner), formatUserID(s.Creator), s.Ctime.Seconds, s.Mtime.Seconds, convertFromCS3OCMShareType(s.ShareType), convertFromCS3OCMShareState(s.State)}

if s.Expiration != nil {
query += ",expiration=?"
Expand Down Expand Up @@ -642,7 +644,7 @@ func (m *mgr) getReceivedByID(ctx context.Context, user *userpb.User, id *ocm.Sh

func (m *mgr) getReceivedByKey(ctx context.Context, user *userpb.User, key *ocm.ShareKey) (*ocm.ReceivedShare, error) {
query := "SELECT id, name, fileid_prefix, item_source, share_with, owner, initiator, ctime, mtime, expiration, type, state FROM ocm_received_shares WHERE owner=? AND fileid_prefix=? AND item_source=? AND share_with=?"
params := []any{formatUserId(key.Owner), key.ResourceId.StorageId, key.ResourceId.OpaqueId, key.Grantee.GetUserId().OpaqueId}
params := []any{formatUserID(key.Owner), key.ResourceId.StorageId, key.ResourceId.OpaqueId, key.Grantee.GetUserId().OpaqueId}

var s dbReceivedShare
if err := m.db.QueryRowContext(ctx, query, params...).Scan(&s.ID, &s.Name, &s.Prefix, &s.ItemSource, &s.ShareWith, &s.Owner, &s.Initiator, &s.Ctime, &s.Mtime, &s.Expiration, &s.Type, &s.State); err != nil {
Expand Down
18 changes: 18 additions & 0 deletions pkg/ocm/share/repository/sql/sql_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
// Copyright 2018-2023 CERN
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// In applying this license, CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

package sql

import (
Expand Down

0 comments on commit 9c0d95d

Please sign in to comment.