DL4J: Как вычислить косинусное сходство между INDArray, полученным из getWordVectorsMean

Я рассчитал VectorMean из двух предложений следующим образом:

String demoString1 = "Enter first label";
String demoString2 = "Enter first name";
        Collection<String> label1 = Splitter.on(' ').splitToList(demoString1);
        Collection<String> label2 = Splitter.on(' ').splitToList(demoString2);

        System.out.println("label1:==>"+label1);
        System.out.println("getWordVectorMatrix->INDArray------------------"+vectors.getWordVectorsMean(label1));

        System.out.println("label2:==>"+label2);
        System.out.println("getWordVectorMatrix->INDArray------------------"+vectors.getWordVectorsMean(label2));

Вывод:

label1:==>[Enter, first, label]
getWordVectorMatrix->INDArray------------------[0.02,  -0.14,  0.07,  -0.10,.............100 dimension vector]
label2:==>[Enter, first, name]
getWordVectorMatrix->INDArray------------------[-0.00,  -0.15,  0.07,  -0.13,............100 dimension vector]

Теперь, как я могу вычислить сходство (косинусное сходство) между обоими предложениями, используя их среднее значение? Я искал, но не смог найти никакого API, доступного в DL4J.


person Om Sao    schedule 06.02.2018    source источник


Ответы (1)


Метод:

public static double cosineSimForSentence(Word2Vec vector, String sentence1, String sentence2){
        Collection<String> label1 = Splitter.on(' ').splitToList(sentence1);
        Collection<String> label2 = Splitter.on(' ').splitToList(sentence2);
        try{
            return Transforms.cosineSim(vector.getWordVectorsMean(label1), vector.getWordVectorsMean(label2));
        }catch(Exception e){
            exceptionMessage = e.getMessage();
        }
        return Transforms.cosineSim(vector.getWordVectorsMean(label1), vector.getWordVectorsMean(label2));

    }

Вызов метода:

System.out.println("Similarity Score between: "+demoString1+" --vs-- "+ demoString2 +":==>"+ cosineSimForSentence(vectors, demoString1, demoString2));
person Om Sao    schedule 04.04.2018