Skip to content

Commit

Permalink
fix: enforce consistency of contract negotiation request and transfer…
Browse files Browse the repository at this point in the history
… request (#4264)
  • Loading branch information
bscholtes1A authored Jun 12, 2024
1 parent cb2e6b5 commit d931c8b
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ public PolicyDefinitionService policyDefinitionService() {
@Provider
public TransferProcessService transferProcessService() {
return new TransferProcessServiceImpl(transferProcessStore, transferProcessManager, transactionContext,
dataAddressValidator, commandHandlerRegistry, flowTypeExtractor);
dataAddressValidator, commandHandlerRegistry, flowTypeExtractor, contractNegotiationStore);
}

@Provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package org.eclipse.edc.connector.controlplane.services.transferprocess;

import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.store.ContractNegotiationStore;
import org.eclipse.edc.connector.controlplane.services.query.QueryValidator;
import org.eclipse.edc.connector.controlplane.services.spi.transferprocess.TransferProcessService;
import org.eclipse.edc.connector.controlplane.transfer.spi.TransferProcessManager;
Expand Down Expand Up @@ -61,16 +62,19 @@ public class TransferProcessServiceImpl implements TransferProcessService {
private final DataAddressValidatorRegistry dataAddressValidator;
private final CommandHandlerRegistry commandHandlerRegistry;
private final FlowTypeExtractor flowTypeExtractor;
private final ContractNegotiationStore contractNegotiationStore;

public TransferProcessServiceImpl(TransferProcessStore transferProcessStore, TransferProcessManager manager,
TransactionContext transactionContext, DataAddressValidatorRegistry dataAddressValidator,
CommandHandlerRegistry commandHandlerRegistry, FlowTypeExtractor flowTypeExtractor) {
CommandHandlerRegistry commandHandlerRegistry, FlowTypeExtractor flowTypeExtractor,
ContractNegotiationStore contractNegotiationStore) {
this.transferProcessStore = transferProcessStore;
this.manager = manager;
this.transactionContext = transactionContext;
this.dataAddressValidator = dataAddressValidator;
this.commandHandlerRegistry = commandHandlerRegistry;
this.flowTypeExtractor = flowTypeExtractor;
this.contractNegotiationStore = contractNegotiationStore;
queryValidator = new QueryValidator(TransferProcess.class, getSubtypes());
}

Expand Down Expand Up @@ -123,6 +127,15 @@ public ServiceResult<List<TransferProcess>> search(QuerySpec query) {

@Override
public @NotNull ServiceResult<TransferProcess> initiateTransfer(TransferRequest request) {
var agreement = contractNegotiationStore.findContractAgreement(request.getContractId());
if (agreement == null) {
return ServiceResult.badRequest("Contract agreement with id %s not found".formatted(request.getContractId()));
}

if (!agreement.getAssetId().equals(request.getAssetId())) {
return ServiceResult.badRequest("Asset id %s in contract agreement does not match asset id in transfer request %s".formatted(agreement.getAssetId(), request.getAssetId()));
}

var flowType = flowTypeExtractor.extract(request.getTransferType()).getContent();

if (flowType == FlowType.PUSH) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,22 @@ void shouldDispatchEventsOnTransferProcessStateChanges(TransferProcessService se

when(identityService.verifyJwtToken(eq(tokenRepresentation), isA(VerificationContext.class))).thenReturn(Result.success(token));

var transferRequest = createTransferRequest();
var agent = mock(ParticipantAgent.class);
var agreement = mock(ContractAgreement.class);
var providerId = "ProviderId";

when(agreement.getAssetId()).thenReturn(transferRequest.getAssetId());
when(agreement.getProviderId()).thenReturn(providerId);
when(agreement.getPolicy()).thenReturn(Policy.Builder.newInstance().build());
when(agent.getIdentity()).thenReturn(providerId);

dispatcherRegistry.register(getTestDispatcher());
when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(mock(Policy.class));
when(negotiationStore.findContractAgreement("contractId")).thenReturn(agreement);
when(policyArchive.findPolicyForContract(matches(transferRequest.getContractId()))).thenReturn(Policy.Builder.newInstance().target(transferRequest.getAssetId()).build());
when(negotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(agreement);
when(agentService.createFor(token)).thenReturn(agent);
eventRouter.register(TransferProcessEvent.class, eventSubscriber);

var transferRequest = createTransferRequest();

var initiateResult = service.initiateTransfer(transferRequest);

await().atMost(TIMEOUT).untilAsserted(() -> {
Expand Down Expand Up @@ -196,10 +196,13 @@ void shouldDispatchEventsOnTransferProcessStateChanges(TransferProcessService se
}

@Test
void shouldTerminateOnInvalidPolicy(TransferProcessService service, EventRouter eventRouter, RemoteMessageDispatcherRegistry dispatcherRegistry) {
void shouldTerminateOnInvalidPolicy(TransferProcessService service, EventRouter eventRouter, RemoteMessageDispatcherRegistry dispatcherRegistry, ContractNegotiationStore negotiationStore) {
dispatcherRegistry.register(getTestDispatcher());
eventRouter.register(TransferProcessEvent.class, eventSubscriber);
var transferRequest = createTransferRequest();
var agreement = mock(ContractAgreement.class);
when(agreement.getAssetId()).thenReturn(transferRequest.getAssetId());
when(negotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(agreement);

service.initiateTransfer(transferRequest);

Expand All @@ -213,12 +216,16 @@ void shouldTerminateOnInvalidPolicy(TransferProcessService service, EventRouter
void shouldDispatchEventOnTransferProcessTerminated(TransferProcessService service,
EventRouter eventRouter,
RemoteMessageDispatcherRegistry dispatcherRegistry,
PolicyArchive policyArchive) {
PolicyArchive policyArchive,
ContractNegotiationStore negotiationStore) {

when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(mock(Policy.class));
var transferRequest = createTransferRequest();
when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(Policy.Builder.newInstance().target(transferRequest.getAssetId()).build());
var agreement = mock(ContractAgreement.class);
when(agreement.getAssetId()).thenReturn(transferRequest.getAssetId());
when(negotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(agreement);
dispatcherRegistry.register(getTestDispatcher());
eventRouter.register(TransferProcessEvent.class, eventSubscriber);
var transferRequest = createTransferRequest();

var initiateResult = service.initiateTransfer(transferRequest);

Expand All @@ -234,10 +241,13 @@ void shouldDispatchEventOnTransferProcessTerminated(TransferProcessService servi
}

@Test
void shouldDispatchEventOnTransferProcessFailure(TransferProcessService service, EventRouter eventRouter, RemoteMessageDispatcherRegistry dispatcherRegistry) {
void shouldDispatchEventOnTransferProcessFailure(TransferProcessService service, EventRouter eventRouter, RemoteMessageDispatcherRegistry dispatcherRegistry, ContractNegotiationStore negotiationStore) {
dispatcherRegistry.register(getTestDispatcher());
eventRouter.register(TransferProcessEvent.class, eventSubscriber);
var transferRequest = createTransferRequest();
var agreement = mock(ContractAgreement.class);
when(agreement.getAssetId()).thenReturn(transferRequest.getAssetId());
when(negotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(agreement);

service.initiateTransfer(transferRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package org.eclipse.edc.connector.controlplane.services.transferprocess;

import org.eclipse.edc.connector.controlplane.contract.spi.negotiation.store.ContractNegotiationStore;
import org.eclipse.edc.connector.controlplane.contract.spi.types.agreement.ContractAgreement;
import org.eclipse.edc.connector.controlplane.services.spi.transferprocess.TransferProcessService;
import org.eclipse.edc.connector.controlplane.transfer.spi.TransferProcessManager;
import org.eclipse.edc.connector.controlplane.transfer.spi.flow.FlowTypeExtractor;
Expand All @@ -26,6 +28,7 @@
import org.eclipse.edc.connector.controlplane.transfer.spi.types.command.ResumeTransferCommand;
import org.eclipse.edc.connector.controlplane.transfer.spi.types.command.SuspendTransferCommand;
import org.eclipse.edc.connector.controlplane.transfer.spi.types.command.TerminateTransferCommand;
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.spi.command.CommandHandlerRegistry;
import org.eclipse.edc.spi.command.CommandResult;
import org.eclipse.edc.spi.query.Criterion;
Expand Down Expand Up @@ -77,9 +80,10 @@ class TransferProcessServiceImplTest {
private final DataAddressValidatorRegistry dataAddressValidator = mock();
private final CommandHandlerRegistry commandHandlerRegistry = mock();
private final FlowTypeExtractor flowTypeExtractor = mock();
private final ContractNegotiationStore contractNegotiationStore = mock();

private final TransferProcessService service = new TransferProcessServiceImpl(store, manager, transactionContext,
dataAddressValidator, commandHandlerRegistry, flowTypeExtractor);
dataAddressValidator, commandHandlerRegistry, flowTypeExtractor, contractNegotiationStore);

@Test
void findById_whenFound() {
Expand Down Expand Up @@ -144,6 +148,7 @@ class InitiateTransfer {
void shouldInitiateTransfer() {
var transferRequest = transferRequest();
var transferProcess = transferProcess();
when(contractNegotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(createContractAgreement(transferProcess.getContractId(), transferRequest.getAssetId()));
when(flowTypeExtractor.extract(any())).thenReturn(StatusResult.success(FlowType.PUSH));
when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success());
when(manager.initiateConsumerRequest(transferRequest)).thenReturn(StatusResult.success(transferProcess));
Expand All @@ -155,7 +160,7 @@ void shouldInitiateTransfer() {
}

@Test
void shouldFail_whenDestinationIsNotValid() {
void shouldFail_whenContractAgreementNotFound() {
when(flowTypeExtractor.extract(any())).thenReturn(StatusResult.success(FlowType.PUSH));
when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.failure(violation("invalid data address", "path")));

Expand All @@ -164,19 +169,57 @@ void shouldFail_whenDestinationIsNotValid() {
assertThat(result).isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(BAD_REQUEST);
assertThat(result.getFailureMessages()).containsExactly("Contract agreement with id %s not found".formatted(transferRequest().getContractId()));
verifyNoInteractions(manager);
}

@Test
void shouldFail_whenDataDestinationNotPassedAndFlowTypeIsPush() {
void shouldFail_whenTpAssetIdNotEqualToAgreementAssetId() {
var transferRequest = transferRequest();
var transferProcess = transferProcess();
when(contractNegotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(createContractAgreement(transferProcess.getContractId(), "other-asset-id"));
when(flowTypeExtractor.extract(any())).thenReturn(StatusResult.success(FlowType.PUSH));
when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success());
when(manager.initiateConsumerRequest(transferRequest)).thenReturn(StatusResult.success(transferProcess));

var result = service.initiateTransfer(transferRequest);

assertThat(result).isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(BAD_REQUEST);
assertThat(result.getFailureMessages()).containsExactly("Asset id %s in contract agreement does not match asset id in transfer request %s".formatted("other-asset-id", transferRequest.getAssetId()));
verifyNoInteractions(manager);
}

@Test
void shouldFail_whenDestinationIsNotValid() {
var transferRequest = transferRequest();
when(contractNegotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(createContractAgreement(transferRequest.getContractId(), transferRequest.getAssetId()));
when(flowTypeExtractor.extract(any())).thenReturn(StatusResult.success(FlowType.PUSH));
var request = TransferRequest.Builder.newInstance()
when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.failure(violation("invalid data address", "path")));

var result = service.initiateTransfer(transferRequest);

assertThat(result).isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(BAD_REQUEST);
assertThat(result.getFailureMessages()).containsExactly("invalid data address");
verifyNoInteractions(manager);
}

@Test
void shouldFail_whenDataDestinationNotPassedAndFlowTypeIsPush() {
var transferRequest = TransferRequest.Builder.newInstance()
.transferType("any")
.assetId(UUID.randomUUID().toString())
.build();
when(contractNegotiationStore.findContractAgreement(transferRequest.getContractId())).thenReturn(createContractAgreement(transferRequest.getContractId(), transferRequest.getAssetId()));
when(flowTypeExtractor.extract(any())).thenReturn(StatusResult.success(FlowType.PUSH));

var result = service.initiateTransfer(request);
var result = service.initiateTransfer(transferRequest);

assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(BAD_REQUEST);
assertThat(result.getFailureMessages()).containsExactly("For PUSH transfers dataDestination must be defined");
verifyNoInteractions(manager);
}
}
Expand Down Expand Up @@ -312,6 +355,17 @@ private TransferProcess transferProcess(TransferProcessStates state, String id)
private TransferRequest transferRequest() {
return TransferRequest.Builder.newInstance()
.dataDestination(DataAddress.Builder.newInstance().type("type").build())
.assetId(UUID.randomUUID().toString())
.build();
}

private ContractAgreement createContractAgreement(String agreementId, String assetId) {
return ContractAgreement.Builder.newInstance()
.id(agreementId)
.providerId(UUID.randomUUID().toString())
.consumerId(UUID.randomUUID().toString())
.assetId(assetId)
.policy(Policy.Builder.newInstance().build())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ public ContractValidationServiceImpl(AssetIndex assetIndex,
@Override
public @NotNull Result<ValidatedConsumerOffer> validateInitialOffer(ParticipantAgent agent, ValidatableConsumerOffer consumerOffer) {
return validateInitialOffer(consumerOffer, agent)
.map(sanitizedPolicy -> {
var offer = createContractOffer(sanitizedPolicy, consumerOffer.getOfferId());
return new ValidatedConsumerOffer(agent.getIdentity(), offer);
});
.compose(policy -> createContractOffer(policy, consumerOffer.getOfferId()))
.map(contractOffer -> new ValidatedConsumerOffer(agent.getIdentity(), contractOffer));
}

@Override
Expand Down Expand Up @@ -163,12 +161,15 @@ private Result<Policy> evaluatePolicy(Policy policy, String scope, ParticipantAg
}

@NotNull
private ContractOffer createContractOffer(Policy policy, ContractOfferId contractOfferId) {
return ContractOffer.Builder.newInstance()
private Result<ContractOffer> createContractOffer(Policy policy, ContractOfferId contractOfferId) {
if (!contractOfferId.assetIdPart().equals(policy.getTarget())) {
return Result.failure("Policy target %s does not match the asset ID in the contract offer %s".formatted(policy.getTarget(), contractOfferId.assetIdPart()));
}
return Result.success(ContractOffer.Builder.newInstance()
.id(contractOfferId.toString())
.policy(policy)
.assetId(contractOfferId.assetIdPart())
.build();
.build());
}

}
Loading

0 comments on commit d931c8b

Please sign in to comment.