diff --git a/main.py b/main.py
index e118185..efd2e24 100644
--- a/main.py
+++ b/main.py
@@ -167,7 +167,6 @@ if __name__ == "__main__":
 
     parser_audio_amp = subparsers.add_parser('amplitude', help='The amplitude editor uses audio amplitude moving averages to find swings from relatively quiet moments to loud moments. This is useful in videos where long moments of quiet are interspersed with loud action filled moments.')
     parser_audio_amp.add_argument(
-        "-f",
         "--factor",
         default=16000,
         help="Subsampling factor",
@@ -175,7 +174,14 @@ if __name__ == "__main__":
         type=int,
     )
 
-    parser_audio_amp = subparsers.add_parser('sentiment', help='The sentiment editor transcribes the speech in a video and runs sentiment analysis on the resulting text. Using moving averages, large swings in sentiment can be correlated to controversial or exciting moments. A GPU with CUDA is recommended for fast results.')
+    parser_sentiment = subparsers.add_parser('sentiment', help='The sentiment editor transcribes the speech in a video and runs sentiment analysis on the resulting text. Using moving averages, large swings in sentiment can be correlated to controversial or exciting moments. A GPU with CUDA is recommended for fast results.')
+    parser_sentiment.add_argument(
+        "--model",
+        default="base",
+        help="The size of the Whisper transcription model being used. Larger models increase computation time.",
+        dest="model_size",
+        choices=["base", "tiny", "small", "medium", "large"],
+    )
 
     parser.add_argument("-p", "--parallelism", dest="parallelism", type=int, default=multiprocessing.cpu_count() - 2, help="The number of cores to use, defaults to N - 2 cores.")
     parser.add_argument("--cost-function", dest="cost", choices=ERROR_FUNCS.keys(), default='quadratic')
diff --git a/src/editors/amplitude/editor.py b/src/editors/amplitude/editor.py
index 6b9d19b..407bfe2 100644
--- a/src/editors/amplitude/editor.py
+++ b/src/editors/amplitude/editor.py
@@ -1,6 +1,5 @@
 from ...mediautils.audio import process_audio, resample
 from ...math.average import np_moving_average
-from ...models.moment import Moment
 from ..common import find_moving_average_highlights
 import numpy as np
 import structlog
diff --git a/src/editors/sentiment/editor.py b/src/editors/sentiment/editor.py
index 903d7d5..9c32ba0 100644
--- a/src/editors/sentiment/editor.py
+++ b/src/editors/sentiment/editor.py
@@ -1,3 +1,71 @@
+import whisper
+import numpy as np
+import structlog
+
+from flair.models import TextClassifier
+from dataclasses import dataclass
+from flair.data import Sentence
+
+from ...math.average import np_moving_average
+from ..common import find_moving_average_highlights
+
+@dataclass
+class TextGlob:
+    """A transcribed segment of speech and its signed sentiment score."""
+    start:float
+    stop:float
+    text:str
+    sentiment:float
 
 class SentimentEditor:
-    pass
+    """Scores the sentiment of transcribed speech to locate highlights."""
+
+    def __init__(self, video_path, audio_path, params):
+        """Transcribe audio_path with Whisper and score each segment with flair.
+
+        video_path is accepted but not used here. params must contain
+        "model_size", the name of the Whisper model to load.
+        """
+        self.logger = structlog.get_logger("sentiment")
+        self.logger.info("loading whisper model", size=params["model_size"])
+        self.model = whisper.load_model(params["model_size"])
+        self.logger.info("transcribing audio", path=audio_path)
+        self.result = self.model.transcribe(audio_path)
+        self.segments = []
+        for segment in self.result['segments']:
+            self.segments.append(TextGlob(segment['start'], segment['end'], segment['text'], 0))
+        classifier = TextClassifier.load('en-sentiment')
+        self.sentiments = []
+        self.logger.info("calculating sentiment on segments", segments=len(self.segments))
+        for segment in self.segments:
+            sentence = Sentence(segment.text)
+            classifier.predict(sentence)
+            # The label score is negated for NEGATIVE labels to give a signed value.
+            sentsum = sentence.labels[0].score
+            if sentence.labels[0].value == "NEGATIVE":
+                sentsum = sentsum * -1
+            segment.sentiment = sentsum
+            self.sentiments.append(sentsum)
+        self.sentiments = np.array(self.sentiments)
+
+    def edit(self, large_window, small_window, params):
+        """Run the moving-average highlight finder over the segment sentiments.
+
+        Both windows are scaled by the number of segments per unit of transcript
+        time before averaging. params is accepted but not used here.
+
+        Returns a (highlights, large_window, small_window) tuple.
+        """
+        # Without segments (or with a zero end time) the segment density
+        # below is undefined, so report no highlights instead of raising.
+        # NOTE(review): assumes callers accept an empty list of highlights.
+        if not self.segments or self.segments[-1].stop <= 0:
+            return [], large_window, small_window
+        end_time = self.segments[-1].stop
+        window_factor = len(self.sentiments) / end_time
+        # NOTE(review): the scaled windows are floats; confirm that
+        # np_moving_average accepts a non-integer window.
+        long_ma = np_moving_average(self.sentiments, large_window * window_factor)
+        short_ma = np_moving_average(self.sentiments, small_window * window_factor)
+        highlights = find_moving_average_highlights(short_ma, long_ma, 1.0 / window_factor)
+        return highlights, large_window, small_window