diff --git a/packages/client/src/components/TagTraining.component.tsx b/packages/client/src/components/TagTraining.component.tsx index 3c877d99..a8f0487e 100644 --- a/packages/client/src/components/TagTraining.component.tsx +++ b/packages/client/src/components/TagTraining.component.tsx @@ -15,6 +15,8 @@ export const TagTrainingComponent: React.FC = (props) const [datasets, setDatasets] = useState([]); const { project } = useProject(); const [getDatasetsQuery, getDatasetsResults] = useGetDatasetsByProjectLazyQuery(); + const [trainingSet, setTrainingSet] = useState([]); + const [taggingSet, setTaggingSet] = useState([]); useEffect(() => { if (project) { @@ -22,9 +24,6 @@ export const TagTrainingComponent: React.FC = (props) } }, [project]); - const trainingSet: Set = new Set(); - const fullSet: Set = new Set(); - const additionalColumns: GridColDef[] = [ { field: 'training', @@ -35,12 +34,10 @@ export const TagTrainingComponent: React.FC = (props) startingValue={false} onLoad={(_entry) => {}} add={(entry) => { - trainingSet.add(entry._id); - props.setTrainingSet(Array.from(trainingSet)); + setTrainingSet([...trainingSet, entry._id]); }} remove={(entry) => { - trainingSet.delete(entry._id); - props.setTrainingSet(Array.from(trainingSet)); + setTrainingSet(trainingSet.filter((entryID) => entryID != entry._id)); }} entry={params.row} /> @@ -52,18 +49,13 @@ export const TagTrainingComponent: React.FC = (props) width: 200, renderCell: (params) => ( { - fullSet.add(entry._id); - props.setTaggingSet(Array.from(fullSet)); - }} + startingValue={false} + onLoad={(_entry) => {}} add={(entry) => { - fullSet.add(entry._id); - props.setTaggingSet(Array.from(fullSet)); + setTaggingSet([...taggingSet, entry._id]); }} remove={(entry) => { - fullSet.delete(entry._id); - props.setTaggingSet(Array.from(fullSet)); + setTaggingSet(taggingSet.filter((entryID) => entryID != entry._id)); }} entry={params.row} /> @@ -71,6 +63,16 @@ export const TagTrainingComponent: React.FC = (props) } ]; + useEffect(() => { + const entries = Array.from(new Set(taggingSet)); + props.setTaggingSet(entries); + }, [taggingSet]); + + useEffect(() => { + const entries = Array.from(new Set(trainingSet)); + props.setTrainingSet(entries); + }, [trainingSet]); + // TODO: In the future, the datasets retrieved should only be datasets // accessible by the current project useEffect(() => { diff --git a/packages/client/src/graphql/graphql.ts b/packages/client/src/graphql/graphql.ts index e8007aa6..486a8f49 100644 --- a/packages/client/src/graphql/graphql.ts +++ b/packages/client/src/graphql/graphql.ts @@ -112,6 +112,7 @@ export type Mutation = { createOrganization: Organization; createStudy: Study; createTags: Array; + createTrainingSet: Scalars['Boolean']['output']; createUploadSession: UploadSession; deleteEntry: Scalars['Boolean']['output']; deleteProject: Scalars['Boolean']['output']; @@ -193,6 +194,12 @@ export type MutationCreateTagsArgs = { }; +export type MutationCreateTrainingSetArgs = { + entries: Array; + study: Scalars['ID']['input']; +}; + + export type MutationCreateUploadSessionArgs = { dataset: Scalars['ID']['input']; }; @@ -507,6 +514,8 @@ export type Tag = { /** Way to rank tags based on order to be tagged */ order: Scalars['Float']['output']; study: Study; + /** If the tag is part of a training */ + training: Scalars['Boolean']['output']; /** The user assigned to the tag */ user?: Maybe; }; diff --git a/packages/client/src/graphql/tag/tag.graphql b/packages/client/src/graphql/tag/tag.graphql index d30a7514..2efa1006 100644 --- a/packages/client/src/graphql/tag/tag.graphql +++ b/packages/client/src/graphql/tag/tag.graphql @@ -4,6 +4,10 @@ mutation createTags($study: ID!, $entries: [ID!]!) { } } +mutation createTrainingSet($study: ID!, $entries: [ID!]!) { + createTrainingSet(study: $study, entries: $entries) +} + mutation setEntryEnabled($study: ID!, $entry: ID!, $enabled: Boolean!) { setEntryEnabled(study: $study, entry: $entry, enabled: $enabled) } @@ -60,3 +64,4 @@ query getTags($study: ID!) { complete } } + diff --git a/packages/client/src/graphql/tag/tag.ts b/packages/client/src/graphql/tag/tag.ts index d3c4b3b5..6417d986 100644 --- a/packages/client/src/graphql/tag/tag.ts +++ b/packages/client/src/graphql/tag/tag.ts @@ -13,6 +13,14 @@ export type CreateTagsMutationVariables = Types.Exact<{ export type CreateTagsMutation = { __typename?: 'Mutation', createTags: Array<{ __typename?: 'Tag', _id: string }> }; +export type CreateTrainingSetMutationVariables = Types.Exact<{ + study: Types.Scalars['ID']['input']; + entries: Array | Types.Scalars['ID']['input']; +}>; + + +export type CreateTrainingSetMutation = { __typename?: 'Mutation', createTrainingSet: boolean }; + export type SetEntryEnabledMutationVariables = Types.Exact<{ study: Types.Scalars['ID']['input']; entry: Types.Scalars['ID']['input']; @@ -96,6 +104,38 @@ export function useCreateTagsMutation(baseOptions?: Apollo.MutationHookOptions; export type CreateTagsMutationResult = Apollo.MutationResult; export type CreateTagsMutationOptions = Apollo.BaseMutationOptions; +export const CreateTrainingSetDocument = gql` + mutation createTrainingSet($study: ID!, $entries: [ID!]!) { + createTrainingSet(study: $study, entries: $entries) +} + `; +export type CreateTrainingSetMutationFn = Apollo.MutationFunction; + +/** + * __useCreateTrainingSetMutation__ + * + * To run a mutation, you first call `useCreateTrainingSetMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useCreateTrainingSetMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [createTrainingSetMutation, { data, loading, error }] = useCreateTrainingSetMutation({ + * variables: { + * study: // value for 'study' + * entries: // value for 'entries' + * }, + * }); + */ +export function useCreateTrainingSetMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(CreateTrainingSetDocument, options); + } +export type CreateTrainingSetMutationHookResult = ReturnType; +export type CreateTrainingSetMutationResult = Apollo.MutationResult; +export type CreateTrainingSetMutationOptions = Apollo.BaseMutationOptions; export const SetEntryEnabledDocument = gql` mutation setEntryEnabled($study: ID!, $entry: ID!, $enabled: Boolean!) { setEntryEnabled(study: $study, entry: $entry, enabled: $enabled) diff --git a/packages/client/src/pages/studies/NewStudy.tsx b/packages/client/src/pages/studies/NewStudy.tsx index 7459b73a..24023038 100644 --- a/packages/client/src/pages/studies/NewStudy.tsx +++ b/packages/client/src/pages/studies/NewStudy.tsx @@ -5,11 +5,18 @@ import { TagTrainingComponent } from '../../components/TagTraining.component'; import { useState, useEffect } from 'react'; import { StudyCreate, TagSchema } from '../../graphql/graphql'; import { PartialStudyCreate } from '../../types/study'; -import { CreateStudyDocument } from '../../graphql/study/study'; +import { CreateStudyDocument, CreateStudyMutation, CreateStudyMutationVariables } from '../../graphql/study/study'; import { useProject } from '../../context/Project.context'; import { useStudy } from '../../context/Study.context'; import { useApolloClient } from '@apollo/client'; -import { CreateTagsDocument } from '../../graphql/tag/tag'; +import { + CreateTagsDocument, + CreateTrainingSetDocument, + CreateTagsMutationVariables, + CreateTagsMutation, + CreateTrainingSetMutation, + CreateTrainingSetMutationVariables +} from '../../graphql/tag/tag'; import { useTranslation } from 'react-i18next'; import { TagFieldFragmentSchema, TagField } from '../../components/tagbuilder/TagProvider'; @@ -20,7 +27,7 @@ export const NewStudy: React.FC = () => { const [tagSchema, setTagSchema] = useState(null); const { project } = useProject(); const { updateStudies } = useStudy(); - const [_trainingSet, setTrainingSet] = useState([]); + const [trainingSet, setTrainingSet] = useState([]); const [taggingSet, setTaggingSet] = useState([]); const apolloClient = useApolloClient(); // The different fields that make up the tag schema @@ -70,7 +77,7 @@ export const NewStudy: React.FC = () => { }; // Make the new study - const result = await apolloClient.mutate({ + const result = await apolloClient.mutate({ mutation: CreateStudyDocument, variables: { study: study } }); @@ -81,10 +88,19 @@ export const NewStudy: React.FC = () => { } // Create the corresponding tags - await apolloClient.mutate({ + await apolloClient.mutate({ mutation: CreateTagsDocument, variables: { study: result.data.createStudy._id, entries: taggingSet } }); + + // Create the training set + await apolloClient.mutate({ + mutation: CreateTrainingSetDocument, + variables: { + study: result.data.createStudy._id, + entries: trainingSet + } + }); updateStudies(); } setActiveStep((prevActiveStep: number) => prevActiveStep + 1); diff --git a/packages/server/src/tag/models/tag.model.ts b/packages/server/src/tag/models/tag.model.ts index d5da7bfb..78408a0c 100644 --- a/packages/server/src/tag/models/tag.model.ts +++ b/packages/server/src/tag/models/tag.model.ts @@ -41,6 +41,10 @@ export class Tag { @Prop() @Field({ description: 'If the tag is enabled as part of the study, way to disable certain tags' }) enabled: boolean; + + @Prop() + @Field({ description: 'If the tag is part of a training' }) + training: boolean; } export type TagDocument = Tag & Document; diff --git a/packages/server/src/tag/models/training-set.ts b/packages/server/src/tag/models/training-set.ts new file mode 100644 index 00000000..b2cf145c --- /dev/null +++ b/packages/server/src/tag/models/training-set.ts @@ -0,0 +1,14 @@ +import { Schema, Prop, SchemaFactory } from '@nestjs/mongoose'; +import { Document } from 'mongoose'; + +@Schema() +export class TrainingSet { + @Prop() + study: string; + + @Prop() + entries: string[]; +} + +export type TrainingSetDocument = TrainingSet & Document; +export const TrainingSetSchema = SchemaFactory.createForClass(TrainingSet); diff --git a/packages/server/src/tag/resolvers/training-set.resolver.ts b/packages/server/src/tag/resolvers/training-set.resolver.ts new file mode 100644 index 00000000..a8cee47b --- /dev/null +++ b/packages/server/src/tag/resolvers/training-set.resolver.ts @@ -0,0 +1,35 @@ +import { Resolver, Mutation, Args, ID } from '@nestjs/graphql'; +import { Entry } from '../../entry/models/entry.model'; +import { EntriesPipe } from '../../entry/pipes/entry.pipe'; +import { TokenPayload } from '../../jwt/token.dto'; +import { TokenContext } from '../../jwt/token.context'; +import { StudyPipe } from '../../study/pipes/study.pipe'; +import { Study } from '../../study/study.model'; +import { TrainingSetService } from '../services/training-set.service'; +import { Inject, UnauthorizedException, UseGuards } from '@nestjs/common'; +import { CASBIN_PROVIDER } from 'src/permission/casbin.provider'; +import * as casbin from 'casbin'; +import { StudyPermissions } from 'src/permission/permissions/study'; +import { JwtAuthGuard } from '../../jwt/jwt.guard'; + +@UseGuards(JwtAuthGuard) +@Resolver() +export class TrainingSetResolver { + constructor( + private readonly trainingService: TrainingSetService, + @Inject(CASBIN_PROVIDER) private readonly enforcer: casbin.Enforcer + ) {} + + @Mutation(() => Boolean) + async createTrainingSet( + @Args('study', { type: () => ID }, StudyPipe) study: Study, + @Args('entries', { type: () => [ID] }, EntriesPipe) entries: Entry[], + @TokenContext() user: TokenPayload + ): Promise { + if (!(await this.enforcer.enforce(user.user_id, StudyPermissions.CREATE, study._id.toString()))) { + throw new UnauthorizedException('User cannot create a training set for this study'); + } + await this.trainingService.create(study, entries); + return true; + } +} diff --git a/packages/server/src/tag/services/training-set.service.ts b/packages/server/src/tag/services/training-set.service.ts new file mode 100644 index 00000000..774a0531 --- /dev/null +++ b/packages/server/src/tag/services/training-set.service.ts @@ -0,0 +1,18 @@ +import { Injectable } from '@nestjs/common'; +import { InjectModel } from '@nestjs/mongoose'; +import { Model } from 'mongoose'; +import { Entry } from '../../entry/models/entry.model'; +import { Study } from '../../study/study.model'; +import { TrainingSet } from '../models/training-set'; + +@Injectable() +export class TrainingSetService { + constructor(@InjectModel(TrainingSet.name) private readonly trainingSetModel: Model) {} + + async create(study: Study, entries: Entry[]): Promise { + return this.trainingSetModel.create({ + study: study._id, + entries: entries.map((entry) => entry._id) + }); + } +} diff --git a/packages/server/src/tag/tag.module.ts b/packages/server/src/tag/tag.module.ts index 8885edaa..0d5e81a7 100644 --- a/packages/server/src/tag/tag.module.ts +++ b/packages/server/src/tag/tag.module.ts @@ -16,12 +16,16 @@ import { TagTransformer } from './services/tag-transformer.service'; import { FieldTransformerFactory } from './transformers/field-transformer-factory'; import { VideoFieldTransformer } from './transformers/video-field-transformer'; import { DatasetModule } from '../dataset/dataset.module'; +import { TrainingSet, TrainingSetSchema } from './models/training-set'; +import { TrainingSetResolver } from './resolvers/training-set.resolver'; +import { TrainingSetService } from './services/training-set.service'; @Module({ imports: [ MongooseModule.forFeature([ { name: Tag.name, schema: TagSchema }, - { name: VideoField.name, schema: VideoFieldSchema } + { name: VideoField.name, schema: VideoFieldSchema }, + { name: TrainingSet.name, schema: TrainingSetSchema } ]), StudyModule, EntryModule, @@ -38,7 +42,9 @@ import { DatasetModule } from '../dataset/dataset.module'; VideoFieldResolver, TagTransformer, FieldTransformerFactory, - VideoFieldTransformer + VideoFieldTransformer, + TrainingSetResolver, + TrainingSetService ] }) export class TagModule {}