Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions packages/client/src/components/TagTraining.component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ export const TagTrainingComponent: React.FC<TagTrainingComponentProps> = (props)
const [datasets, setDatasets] = useState<Dataset[]>([]);
const { project } = useProject();
const [getDatasetsQuery, getDatasetsResults] = useGetDatasetsByProjectLazyQuery();
const [trainingSet, setTrainingSet] = useState<string[]>([]);
const [taggingSet, setTaggingSet] = useState<string[]>([]);

useEffect(() => {
if (project) {
getDatasetsQuery({ variables: { project: project._id } });
}
}, [project]);

const trainingSet: Set<string> = new Set();
const fullSet: Set<string> = new Set();

const additionalColumns: GridColDef[] = [
{
field: 'training',
Expand All @@ -35,12 +34,10 @@ export const TagTrainingComponent: React.FC<TagTrainingComponentProps> = (props)
startingValue={false}
onLoad={(_entry) => {}}
add={(entry) => {
trainingSet.add(entry._id);
props.setTrainingSet(Array.from(trainingSet));
setTrainingSet([...trainingSet, entry._id]);
}}
remove={(entry) => {
trainingSet.delete(entry._id);
props.setTrainingSet(Array.from(trainingSet));
setTrainingSet(trainingSet.filter((entryID) => entryID != entry._id));
}}
entry={params.row}
/>
Expand All @@ -52,25 +49,30 @@ export const TagTrainingComponent: React.FC<TagTrainingComponentProps> = (props)
width: 200,
renderCell: (params) => (
<EditSetSwitch
startingValue={true}
onLoad={(entry) => {
fullSet.add(entry._id);
props.setTaggingSet(Array.from(fullSet));
}}
startingValue={false}
onLoad={(_entry) => {}}
add={(entry) => {
fullSet.add(entry._id);
props.setTaggingSet(Array.from(fullSet));
setTaggingSet([...taggingSet, entry._id]);
}}
remove={(entry) => {
fullSet.delete(entry._id);
props.setTaggingSet(Array.from(fullSet));
setTaggingSet(taggingSet.filter((entryID) => entryID != entry._id));
}}
entry={params.row}
/>
)
}
];

useEffect(() => {
const entries = Array.from(new Set(taggingSet));
props.setTaggingSet(entries);
}, [taggingSet]);

useEffect(() => {
const entries = Array.from(new Set(trainingSet));
props.setTrainingSet(entries);
}, [trainingSet]);

// TODO: In the future, the datasets retrieved should only be datasets
// accessible by the current project
useEffect(() => {
Expand Down
9 changes: 9 additions & 0 deletions packages/client/src/graphql/graphql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ export type Mutation = {
createOrganization: Organization;
createStudy: Study;
createTags: Array<Tag>;
createTrainingSet: Scalars['Boolean']['output'];
createUploadSession: UploadSession;
deleteEntry: Scalars['Boolean']['output'];
deleteProject: Scalars['Boolean']['output'];
Expand Down Expand Up @@ -193,6 +194,12 @@ export type MutationCreateTagsArgs = {
};


export type MutationCreateTrainingSetArgs = {
entries: Array<Scalars['ID']['input']>;
study: Scalars['ID']['input'];
};


export type MutationCreateUploadSessionArgs = {
dataset: Scalars['ID']['input'];
};
Expand Down Expand Up @@ -507,6 +514,8 @@ export type Tag = {
/** Way to rank tags based on order to be tagged */
order: Scalars['Float']['output'];
study: Study;
/** If the tag is part of a training */
training: Scalars['Boolean']['output'];
/** The user assigned to the tag */
user?: Maybe<Scalars['String']['output']>;
};
Expand Down
5 changes: 5 additions & 0 deletions packages/client/src/graphql/tag/tag.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ mutation createTags($study: ID!, $entries: [ID!]!) {
}
}

mutation createTrainingSet($study: ID!, $entries: [ID!]!) {
createTrainingSet(study: $study, entries: $entries)
}

mutation setEntryEnabled($study: ID!, $entry: ID!, $enabled: Boolean!) {
setEntryEnabled(study: $study, entry: $entry, enabled: $enabled)
}
Expand Down Expand Up @@ -60,3 +64,4 @@ query getTags($study: ID!) {
complete
}
}

40 changes: 40 additions & 0 deletions packages/client/src/graphql/tag/tag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ export type CreateTagsMutationVariables = Types.Exact<{

export type CreateTagsMutation = { __typename?: 'Mutation', createTags: Array<{ __typename?: 'Tag', _id: string }> };

export type CreateTrainingSetMutationVariables = Types.Exact<{
study: Types.Scalars['ID']['input'];
entries: Array<Types.Scalars['ID']['input']> | Types.Scalars['ID']['input'];
}>;


export type CreateTrainingSetMutation = { __typename?: 'Mutation', createTrainingSet: boolean };

export type SetEntryEnabledMutationVariables = Types.Exact<{
study: Types.Scalars['ID']['input'];
entry: Types.Scalars['ID']['input'];
Expand Down Expand Up @@ -96,6 +104,38 @@ export function useCreateTagsMutation(baseOptions?: Apollo.MutationHookOptions<C
export type CreateTagsMutationHookResult = ReturnType<typeof useCreateTagsMutation>;
export type CreateTagsMutationResult = Apollo.MutationResult<CreateTagsMutation>;
export type CreateTagsMutationOptions = Apollo.BaseMutationOptions<CreateTagsMutation, CreateTagsMutationVariables>;
export const CreateTrainingSetDocument = gql`
mutation createTrainingSet($study: ID!, $entries: [ID!]!) {
createTrainingSet(study: $study, entries: $entries)
}
`;
export type CreateTrainingSetMutationFn = Apollo.MutationFunction<CreateTrainingSetMutation, CreateTrainingSetMutationVariables>;

/**
* __useCreateTrainingSetMutation__
*
* To run a mutation, you first call `useCreateTrainingSetMutation` within a React component and pass it any options that fit your needs.
* When your component renders, `useCreateTrainingSetMutation` returns a tuple that includes:
* - A mutate function that you can call at any time to execute the mutation
* - An object with fields that represent the current status of the mutation's execution
*
* @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2;
*
* @example
* const [createTrainingSetMutation, { data, loading, error }] = useCreateTrainingSetMutation({
* variables: {
* study: // value for 'study'
* entries: // value for 'entries'
* },
* });
*/
export function useCreateTrainingSetMutation(baseOptions?: Apollo.MutationHookOptions<CreateTrainingSetMutation, CreateTrainingSetMutationVariables>) {
const options = {...defaultOptions, ...baseOptions}
return Apollo.useMutation<CreateTrainingSetMutation, CreateTrainingSetMutationVariables>(CreateTrainingSetDocument, options);
}
export type CreateTrainingSetMutationHookResult = ReturnType<typeof useCreateTrainingSetMutation>;
export type CreateTrainingSetMutationResult = Apollo.MutationResult<CreateTrainingSetMutation>;
export type CreateTrainingSetMutationOptions = Apollo.BaseMutationOptions<CreateTrainingSetMutation, CreateTrainingSetMutationVariables>;
export const SetEntryEnabledDocument = gql`
mutation setEntryEnabled($study: ID!, $entry: ID!, $enabled: Boolean!) {
setEntryEnabled(study: $study, entry: $entry, enabled: $enabled)
Expand Down
26 changes: 21 additions & 5 deletions packages/client/src/pages/studies/NewStudy.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@ import { TagTrainingComponent } from '../../components/TagTraining.component';
import { useState, useEffect } from 'react';
import { StudyCreate, TagSchema } from '../../graphql/graphql';
import { PartialStudyCreate } from '../../types/study';
import { CreateStudyDocument } from '../../graphql/study/study';
import { CreateStudyDocument, CreateStudyMutation, CreateStudyMutationVariables } from '../../graphql/study/study';
import { useProject } from '../../context/Project.context';
import { useStudy } from '../../context/Study.context';
import { useApolloClient } from '@apollo/client';
import { CreateTagsDocument } from '../../graphql/tag/tag';
import {
CreateTagsDocument,
CreateTrainingSetDocument,
CreateTagsMutationVariables,
CreateTagsMutation,
CreateTrainingSetMutation,
CreateTrainingSetMutationVariables
} from '../../graphql/tag/tag';
import { useTranslation } from 'react-i18next';
import { TagFieldFragmentSchema, TagField } from '../../components/tagbuilder/TagProvider';

Expand All @@ -20,7 +27,7 @@ export const NewStudy: React.FC = () => {
const [tagSchema, setTagSchema] = useState<TagSchema | null>(null);
const { project } = useProject();
const { updateStudies } = useStudy();
const [_trainingSet, setTrainingSet] = useState<string[]>([]);
const [trainingSet, setTrainingSet] = useState<string[]>([]);
const [taggingSet, setTaggingSet] = useState<string[]>([]);
const apolloClient = useApolloClient();
// The different fields that make up the tag schema
Expand Down Expand Up @@ -70,7 +77,7 @@ export const NewStudy: React.FC = () => {
};

// Make the new study
const result = await apolloClient.mutate({
const result = await apolloClient.mutate<CreateStudyMutation, CreateStudyMutationVariables>({
mutation: CreateStudyDocument,
variables: { study: study }
});
Expand All @@ -81,10 +88,19 @@ export const NewStudy: React.FC = () => {
}

// Create the corresponding tags
await apolloClient.mutate({
await apolloClient.mutate<CreateTagsMutation, CreateTagsMutationVariables>({
mutation: CreateTagsDocument,
variables: { study: result.data.createStudy._id, entries: taggingSet }
});

// Create the training set
await apolloClient.mutate<CreateTrainingSetMutation, CreateTrainingSetMutationVariables>({
mutation: CreateTrainingSetDocument,
variables: {
study: result.data.createStudy._id,
entries: trainingSet
}
});
updateStudies();
}
setActiveStep((prevActiveStep: number) => prevActiveStep + 1);
Expand Down
4 changes: 4 additions & 0 deletions packages/server/src/tag/models/tag.model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ export class Tag {
@Prop()
@Field({ description: 'If the tag is enabled as part of the study, way to disable certain tags' })
enabled: boolean;

@Prop()
@Field({ description: 'If the tag is part of a training' })
training: boolean;
}

export type TagDocument = Tag & Document;
Expand Down
14 changes: 14 additions & 0 deletions packages/server/src/tag/models/training-set.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { Schema, Prop, SchemaFactory } from '@nestjs/mongoose';
import { Document } from 'mongoose';

@Schema()
export class TrainingSet {
@Prop()
study: string;

@Prop()
entries: string[];
}

export type TrainingSetDocument = TrainingSet & Document;
export const TrainingSetSchema = SchemaFactory.createForClass(TrainingSet);
35 changes: 35 additions & 0 deletions packages/server/src/tag/resolvers/training-set.resolver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { Resolver, Mutation, Args, ID } from '@nestjs/graphql';
import { Entry } from '../../entry/models/entry.model';
import { EntriesPipe } from '../../entry/pipes/entry.pipe';
import { TokenPayload } from '../../jwt/token.dto';
import { TokenContext } from '../../jwt/token.context';
import { StudyPipe } from '../../study/pipes/study.pipe';
import { Study } from '../../study/study.model';
import { TrainingSetService } from '../services/training-set.service';
import { Inject, UnauthorizedException, UseGuards } from '@nestjs/common';
import { CASBIN_PROVIDER } from 'src/permission/casbin.provider';
import * as casbin from 'casbin';
import { StudyPermissions } from 'src/permission/permissions/study';
import { JwtAuthGuard } from '../../jwt/jwt.guard';

@UseGuards(JwtAuthGuard)
@Resolver()
export class TrainingSetResolver {
constructor(
private readonly trainingService: TrainingSetService,
@Inject(CASBIN_PROVIDER) private readonly enforcer: casbin.Enforcer
) {}

@Mutation(() => Boolean)
async createTrainingSet(
@Args('study', { type: () => ID }, StudyPipe) study: Study,
@Args('entries', { type: () => [ID] }, EntriesPipe) entries: Entry[],
@TokenContext() user: TokenPayload
): Promise<boolean> {
if (!(await this.enforcer.enforce(user.user_id, StudyPermissions.CREATE, study._id.toString()))) {
throw new UnauthorizedException('User cannot create a training set for this study');
}
await this.trainingService.create(study, entries);
return true;
}
}
18 changes: 18 additions & 0 deletions packages/server/src/tag/services/training-set.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Injectable } from '@nestjs/common';
import { InjectModel } from '@nestjs/mongoose';
import { Model } from 'mongoose';
import { Entry } from '../../entry/models/entry.model';
import { Study } from '../../study/study.model';
import { TrainingSet } from '../models/training-set';

@Injectable()
export class TrainingSetService {
constructor(@InjectModel(TrainingSet.name) private readonly trainingSetModel: Model<TrainingSet>) {}

async create(study: Study, entries: Entry[]): Promise<TrainingSet> {
return this.trainingSetModel.create({
study: study._id,
entries: entries.map((entry) => entry._id)
});
}
}
10 changes: 8 additions & 2 deletions packages/server/src/tag/tag.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ import { TagTransformer } from './services/tag-transformer.service';
import { FieldTransformerFactory } from './transformers/field-transformer-factory';
import { VideoFieldTransformer } from './transformers/video-field-transformer';
import { DatasetModule } from '../dataset/dataset.module';
import { TrainingSet, TrainingSetSchema } from './models/training-set';
import { TrainingSetResolver } from './resolvers/training-set.resolver';
import { TrainingSetService } from './services/training-set.service';

@Module({
imports: [
MongooseModule.forFeature([
{ name: Tag.name, schema: TagSchema },
{ name: VideoField.name, schema: VideoFieldSchema }
{ name: VideoField.name, schema: VideoFieldSchema },
{ name: TrainingSet.name, schema: TrainingSetSchema }
]),
StudyModule,
EntryModule,
Expand All @@ -38,7 +42,9 @@ import { DatasetModule } from '../dataset/dataset.module';
VideoFieldResolver,
TagTransformer,
FieldTransformerFactory,
VideoFieldTransformer
VideoFieldTransformer,
TrainingSetResolver,
TrainingSetService
]
})
export class TagModule {}