Преобразование RDD в LabeledPoint

Если у меня есть RDD с примерно 500 столбцами и 200 миллионами строк, а RDD.columns.indexOf("target", 0) показывает Int = 77, который говорит мне, что моя целевая зависимая переменная находится в столбце номер 77. Но у меня недостаточно знаний о том, как выбрать нужные (частичные) столбцы в качестве функций (скажем, мне нужны столбцы с 23 по 59, со 111 по 357, с 399 по 489). Мне интересно, могу ли я применить такие:

val data = rdd.map(col => new LabeledPoint(
    col(77).toDouble, Vectors.dense(??.map(x => x.toDouble).toArray))

Любые предложения или рекомендации будут высоко оценены.

Может быть, я перепутал RDD с DataFrame, я могу преобразовать RDD в DataFrame с помощью .toDF() или проще достичь цели с DataFrame, чем с RDD.


person Richard Liu    schedule 26.07.2015    source источник


Ответы (1)


Я предполагаю, что ваши данные выглядят примерно так:

import scala.util.Random.{setSeed, nextDouble}
setSeed(1)

case class Record(
    foo: Double, target: Double, x1: Double, x2: Double, x3: Double)

val rows = sc.parallelize(
    (1 to 10).map(_ => Record(
        nextDouble, nextDouble, nextDouble, nextDouble, nextDouble
   ))
)
val df = sqlContext.createDataFrame(rows)
df.registerTempTable("df")

sqlContext.sql("""
  SELECT ROUND(foo, 2) foo,
         ROUND(target, 2) target,
         ROUND(x1, 2) x1,
         ROUND(x2, 2) x2,
         ROUND(x2, 2) x3 
  FROM df""").show

Итак, у нас есть данные, как показано ниже:

+----+------+----+----+----+
| foo|target|  x1|  x2|  x3|
+----+------+----+----+----+
|0.73|  0.41|0.21|0.33|0.33|
|0.01|  0.96|0.94|0.95|0.95|
| 0.4|  0.35|0.29|0.51|0.51|
|0.77|  0.66|0.16|0.38|0.38|
|0.69|  0.81|0.01|0.52|0.52|
|0.14|  0.48|0.54|0.58|0.58|
|0.62|  0.18|0.01|0.16|0.16|
|0.54|  0.97|0.25|0.39|0.39|
|0.43|  0.23|0.89|0.04|0.04|
|0.66|  0.12|0.65|0.98|0.98|
+----+------+----+----+----+

и мы хотим игнорировать foo и x2 и извлечь LabeledPoint(target, Array(x1, x3)):

// Map feature names to indices
val featInd = List("x1", "x3").map(df.columns.indexOf(_))

// Or if you want to exclude columns
val ignored = List("foo", "target", "x2")
val featInd = df.columns.diff(ignored).map(df.columns.indexOf(_))

// Get index of target
val targetInd = df.columns.indexOf("target") 

df.rdd.map(r => LabeledPoint(
   r.getDouble(targetInd), // Get target value
   // Map feature indices to values
   Vectors.dense(featInd.map(r.getDouble(_)).toArray) 
))
person zero323    schedule 26.07.2015
comment
Отличный код! и это работает очень хорошо. Я только что сделал небольшую модификацию для опечатки val targetInd = df.columns.indexOf("target") - person Richard Liu; 26.07.2015
comment
действительно ценю. Есть ли быстрый способ исключить функцию из списка в вашем примере? Скажем, у вас есть val featInd = List("x1", "x3).map...., что, если у меня есть 200 функций, которые мне нужны, и я подавляю только 3 из них? Что-то вроде val featInd = De-List("x2").map....? - person Richard Liu; 26.07.2015
comment
Конечно, вы можете использовать filter или diff. Я добавил пример. - person zero323; 26.07.2015
comment
Ты супер!! diff() это! - person Richard Liu; 26.07.2015
comment
Надеюсь, вы не возражаете, если я добавлю еще один вопрос... что, если исходный rdd через JDBC, и все они имеют тип данных Decimal? Кажется, что LabledPoint принимает только Double. Любой быстрый способ преобразовать столбцы из десятичного числа в двойное? - person Richard Liu; 27.07.2015
comment
Это java.math.BigDecimal, а не scala.math.BigDecimal, верно? Если это так, вы можете заменить r.getDouble(_) на r.getDecimal(_).floatValue.toDouble. - person zero323; 27.07.2015
comment
Давайте продолжим обсуждение в чате. - person Richard Liu; 27.07.2015
comment
У меня есть существующий DF, который состоит из всех строковых полей. Конечно, я получаю ошибки при преобразовании labeledpoint. Мне нужно изменить эту строку, чтобы преобразовать строки в вектор признаков. Должен ли я сначала указать все поля или есть способ добавить это? - person Jimmy Hendricks; 13.10.2016
comment
Любой способ преобразовать r.getDouble(_) во что-то вроде r.getDouble(_)>0 ? 1:0 Вероятно, это ближе к java, но вы поняли. - person Dylan_Larkin; 13.11.2017