Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from Orange.widgets.widget import OWWidget, Msg, Output, Input
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.data import Table, Domain, DiscreteVariable, StringVariable
from Orange.data import \
Table, Domain, DiscreteVariable, StringVariable, ContinuousVariable
from Orange.data.util import SharedComputeValue, get_unique_names

from orangewidget.settings import Setting
Expand All @@ -20,13 +21,25 @@ def get_substrings(values, delimiter):
- {""})


class SplitColumn:
class SplitColumnBase:
    """Common state and identity for callables that split a column.

    On construction the distinct entries of column ``attr`` in ``data`` are
    passed to ``get_substrings`` together with ``delimiter``; the result is
    stored as the tuple ``new_values``.

    Two instances are equal only when they are of the same concrete class
    and agree on ``attr``, ``delimiter`` and ``new_values``.
    """

    def __init__(self, data, attr, delimiter):
        self.attr = attr
        self.delimiter = delimiter
        # Only distinct column entries matter for collecting substrings.
        column = set(data.get_column(self.attr))
        self.new_values = tuple(get_substrings(column, self.delimiter))

    def __eq__(self, other):
        # The type check keeps different subclasses (which compute different
        # results from the same fields) from comparing equal, and makes
        # comparison with unrelated objects return False instead of raising
        # AttributeError.
        return type(self) is type(other) \
            and self.attr == other.attr \
            and self.delimiter == other.delimiter \
            and self.new_values == other.new_values

    def __hash__(self):
        return hash((self.attr, self.delimiter, self.new_values))


class SplitColumnOneHot(SplitColumnBase):
InheritEq = True

def __call__(self, data):
column = data.get_column(self.attr)
values = [{ss.strip() for ss in s.split(self.delimiter)}
Expand All @@ -35,62 +48,83 @@ def __call__(self, data):
dtype=int)
for v in self.new_values}

def __eq__(self, other):
return self.attr == other.attr \
and self.delimiter == other.delimiter \
and self.new_values == other.new_values

def __hash__(self):
return hash((self.attr, self.delimiter, self.new_values))
class SplitColumnCounts(SplitColumnBase):
    """Count occurrences of each known substring in every row of the column.

    Calling the instance with ``data`` returns a dict that maps each entry of
    ``new_values`` to a float array holding, per row, how many times that
    (whitespace-stripped) substring appears in the row's text.
    """
    InheritEq = True

    def __call__(self, data):
        # Function-scope import so the module's import section stays as is.
        from collections import Counter

        column = data.get_column(self.attr)
        # Tally every row once; the original rescanned each row's parts with
        # list.count() for every substring.
        row_counts = [
            Counter(part.strip() for part in text.split(self.delimiter))
            for text in column]
        # A Counter yields 0 for substrings absent from a row.
        return {value: np.array([counts[value] for counts in row_counts],
                                dtype=float)
                for value in self.new_values}


class OneHotStrings(SharedComputeValue):
class StringEncodingBase(SharedComputeValue):
    """Shared compute value bound to a single substring.

    ``fn`` is forwarded to ``SharedComputeValue``; ``new_feature`` names the
    substring this instance stands for and takes part in equality and
    hashing on top of whatever the base class compares.
    """

    def __init__(self, fn, new_feature):
        super().__init__(fn)
        self.new_feature = new_feature

    def __eq__(self, other):
        base_equal = super().__eq__(other)
        # Short-circuits exactly like the base result dictates, so
        # ``other.new_feature`` is only read when the base comparison holds.
        return base_equal and self.new_feature == other.new_feature

    def __hash__(self):
        base_hash = super().__hash__()
        return base_hash ^ hash(self.new_feature)


class OneHotStrings(StringEncodingBase):
    """Indicator column (0/1) for one substring of a string variable."""
    InheritEq = True

    def compute(self, data, shared_data):
        """Return a float array of ``len(data)`` zeros with ones at the
        positions that ``shared_data`` lists for ``self.new_feature``."""
        marked_rows = shared_data[self.new_feature]
        indicator = np.zeros(len(data))
        indicator[marked_rows] = 1
        return indicator

def __eq__(self, other):
return super().__eq__(other) and self.new_feature == other.new_feature

def __hash__(self):
return super().__hash__() ^ hash(self.new_feature)
class CountStrings(StringEncodingBase):
    """Count column for one substring of a string variable."""
    InheritEq = True

    def compute(self, data, shared_data):
        """Return the entry that ``shared_data`` holds for
        ``self.new_feature``, unchanged; ``data`` is not consulted."""
        counts = shared_data[self.new_feature]
        return counts


class OneHotDiscrete:
def __init__(self, variable, delimiter, value):
class DiscreteEncoding:
    """Encode presence or count of one substring within a discrete variable.

    Parameters
    ----------
    variable : discrete variable whose ``values`` are split by ``delimiter``
    delimiter : str used to split each value into parts
    onehot : bool; True yields 0/1 indicators, False yields occurrence counts
    value : the substring this instance looks for
    """

    def __init__(self, variable, delimiter, onehot, value):
        self.variable = variable
        self.delimiter = delimiter
        self.onehot = onehot
        self.value = value

    def __call__(self, data):
        """Return a float column for ``data``; rows with a missing (NaN)
        value of ``variable`` stay NaN."""
        column = data.get_column(self.variable).astype(float)
        col = np.zeros(len(column))
        col[np.isnan(column)] = np.nan
        for val_idx, value in enumerate(self.variable.values):
            # Strip the parts so that "a; b" matches substring "b", in line
            # with how the string-variable splitters treat their parts.
            parts = [part.strip() for part in value.split(self.delimiter)]
            if self.onehot:
                col[column == val_idx] = int(self.value in parts)
            else:
                col[column == val_idx] = parts.count(self.value)
        return col

    def __eq__(self, other):
        # The type check makes comparison with unrelated objects return
        # False instead of raising AttributeError.
        return type(self) is type(other) \
            and self.variable == other.variable \
            and self.value == other.value \
            and self.delimiter == other.delimiter \
            and self.onehot == other.onehot

    def __hash__(self):
        return hash((self.variable, self.value, self.delimiter, self.onehot))


class OWTextToColumns(OWWidget):
name = "Text to Columns"
class OWSplit(OWWidget):
name = "Split"
description = "Split text or categorical variables into binary indicators"
icon = "icons/TextToColumns.svg"
keywords = ["split"]
icon = "icons/Split.svg"
keywords = ["text to columns", "word encoding", "questionnaire", "survey",
"term", "word presence", "word counts", "categorical encoding",
"indicator variables"]
priority = 700
replaces = ["orangecontrib.prototypes.widgets.owsplit.OWSplit"]

Expand All @@ -106,9 +140,13 @@ class Warning(OWWidget.Warning):
want_main_area = False
resizing_enabled = False

Categorical, Numerical, Counts = range(3)
OutputLabels = ("Categorical (No, Yes)", "Numerical (0, 1)", "Counts")

settingsHandler = DomainContextHandler()
attribute = ContextSetting(None)
delimiter = ContextSetting(";")
output_type = ContextSetting(Categorical)
auto_apply = Setting(True)

def __init__(self):
Expand All @@ -123,8 +161,14 @@ def __init__(self):
model=DomainModel(valid_types=(StringVariable,
DiscreteVariable)))
gui.lineEdit(
variable_select_box, self, "delimiter",
orientation=Qt.Horizontal, callback=self.apply.deferred)
variable_select_box, self, "delimiter", "Delimiter: ",
orientation=Qt.Horizontal, callback=self.apply.deferred,
controlWidth=20).box.layout().addStretch(1)

gui.radioButtonsInBox(
self.controlArea, self, "output_type", self.OutputLabels,
box="Output Values",
callback=self.apply.deferred)

gui.auto_apply(self.buttonsArea, self, commit=self.apply)

Expand All @@ -150,28 +194,44 @@ def apply(self):
self.Outputs.data.send(None)
return
var = self.data.domain[self.attribute]

if var.is_discrete:
values = get_substrings(var.values, self.delimiter)
computer = partial(OneHotDiscrete, var, self.delimiter)
else:
sc = SplitColumn(self.data, var, self.delimiter)
values = sc.new_values
computer = partial(OneHotStrings, sc)
names = get_unique_names(self.data.domain, values, equal_numbers=False)

new_columns = tuple(DiscreteVariable(
name, values=("0", "1"), compute_value=computer(value)
) for value, name in zip(values, names))

values, computer = self._get_compute_value(var)
new_columns = self._get_new_columns(values, computer)
new_domain = Domain(
self.data.domain.attributes + new_columns,
self.data.domain.class_vars, self.data.domain.metas
)
extended_data = self.data.transform(new_domain)
self.Outputs.data.send(extended_data)

def _get_compute_value(self, var):
    """Pick the substrings of ``var`` and a factory for their compute values.

    Returns a pair ``(values, computer)``: ``values`` are the substrings
    found for ``var`` under ``self.delimiter``; ``computer`` is a partial
    that, called with one substring, builds its compute value.
    """
    if var.is_discrete:
        values = get_substrings(var.values, self.delimiter)
        # Anything other than Counts is encoded as a 0/1 indicator.
        as_indicator = self.output_type != self.Counts
        return values, partial(
            DiscreteEncoding, var, self.delimiter, as_indicator)

    # String variable: one splitter instance is shared by all new columns.
    if self.output_type == self.Counts:
        splitter_cls, encoder_cls = SplitColumnCounts, CountStrings
    else:
        splitter_cls, encoder_cls = SplitColumnOneHot, OneHotStrings
    splitter = splitter_cls(self.data, var, self.delimiter)
    return splitter.new_values, partial(encoder_cls, splitter)

def _get_new_columns(self, values, computer):
    """Build one new variable per substring in ``values``.

    Names come from ``get_unique_names`` against the current domain. With
    the Categorical output type the variables are discrete with values
    ("No", "Yes"); otherwise they are continuous. Each variable's compute
    value is ``computer(value)``.
    """
    names = get_unique_names(self.data.domain, values, equal_numbers=False)

    if self.output_type == self.Categorical:
        def make_variable(name, compute_value):
            return DiscreteVariable(
                name, ("No", "Yes"), compute_value=compute_value)
    else:
        def make_variable(name, compute_value):
            return ContinuousVariable(name, compute_value=compute_value)

    return tuple(
        make_variable(name, computer(value))
        for value, name in zip(values, names))


if __name__ == "__main__": # pragma: no cover
WidgetPreview(OWTextToColumns).run(Table.from_file(
"tests/orange-in-education.tab"))
WidgetPreview(OWSplit).run(Table.from_file("tests/orange-in-education.tab"))
Loading
Loading